# See README.md in this directory for more guidance

# *********NB: _cast_* operators are DEPRECATED and will be removed
# eventually. These were previously used before TorchScript IR supported
# representing ScalarType's. They are now superseded by usage of
# `aten::to()`. The ops remain here for backward compatibility purposes.

# DEPRECATED. DO NOT USE
- func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
  variants: function

# DEPRECATED. DO NOT USE
- func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor
  variants: function

# DEPRECATED. DO NOT USE
- func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor
  variants: function

# DEPRECATED. DO NOT USE
- func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor
  variants: function

# DEPRECATED. DO NOT USE
- func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor
  variants: function

# DEPRECATED. DO NOT USE
- func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor
  variants: function

# DEPRECATED. DO NOT USE
- func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor
  variants: function

# DEPRECATED. DO NOT USE
- func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor
  variants: function

# Computes the gradient of current tensor w.r.t. graph leaves.
- func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
  manual_cpp_binding: True
  variants: method

# DEPRECATED. Sets the tensor data held by this `Variable` to be the same as
# `new_data`. It requires that `new_data` and `Variable` have compatible tensor
# type, by checking `_has_compatible_shallow_copy_type(this, new_data)`.
#
# This function is deprecated because it doesn't really make sense in a world
# where Variables *are* Tensors (as opposed to them containing tensors, which
# is what the previous interpretation was.)
- func: set_data(Tensor(a!) self, Tensor new_data) -> ()
  manual_cpp_binding: True
  variants: method

- func: data(Tensor self) -> Tensor
  manual_cpp_binding: True
  variants: method

# True if this `Variable` is a leaf and thus does not have a `grad_fn`.
- func: is_leaf(Tensor self) -> bool
  manual_cpp_binding: True
  variants: method

# Returns the output index of this variable from the forward operation that
# produced it. Conversely, it returns the input index of the gradient `Node` to
# which this `Variable` is connected (because in the gradient computation,
# inputs and outputs switch meaning). For example:
#
#   y0, y1, y2 = f(x)
#   assert y0.output_nr == 0
#   assert y1.output_nr == 1
#   assert y2.output_nr == 2
#
- func: output_nr(Tensor self) -> int
  manual_cpp_binding: True
  variants: method

- func: _version(Tensor self) -> int
  manual_cpp_binding: True
  variants: method

- func: requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!)
  manual_cpp_binding: True
  variants: method

# Enables .grad attribute for non-leaf Tensors.
- func: retain_grad(Tensor(a!) self) -> ()
  manual_cpp_binding: True
  variants: method

- func: retains_grad(Tensor self) -> bool
  manual_cpp_binding: True
  variants: method

- func: _fw_primal(Tensor(a) self, int level) -> Tensor(a)
  variants: method
  dispatch:
    CompositeExplicitAutograd: _fw_primal

- func: _make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)
  variants: function
  dispatch:
    CompositeExplicitAutograd: _make_dual

- func: _unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)
  variants: function

# NOTE: [_new_zeros_with_same_feature_meta]
# This function creates a new tensor with the layout and TensorOptions
# of `other`, but also takes into account the batch dimensions of `self`.
#
# This function has a couple of extra constraints because it is also used for `jvp`
# in functorch.
# - It is used for forward AD because of the restriction that the primal and
#   tangent must have the same layout.
# - We cannot assume that `self` and `other` have the same sizes or even dim,
#   because in the inplace-over-view case, `other` is the base tensor, and
#   `self` is the forward grad with respect to the view, which can have an
#   entirely different shape.
# - It takes the number of batch dims for `self` because we also handle
#   some batching logic. We handle that here instead of in a batching rule because
#   we'd like to avoid calling as_strided in the batching rule (so as to enable
#   nested vmap in functorch).
# - It needs to be CompositeExplicitAutograd for jvp support in functorch.
#   functorch currently relies on TensorWrapper, which does not have storage;
#   CompositeExplicitAutograd makes sure the TensorWrapper is unwrapped.
# - This function may eventually take another int argument to store the
#   number of batch dims for `other` once we support that use case.
- func: _new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutograd: _new_zeros_with_same_feature_meta
  autogen: _new_zeros_with_same_feature_meta.out

# This function compares the storage numel of self with that of other, where
# storage numel is computed as: `other.storage().nbytes() / other.itemsize()`.
# We create this function for composite compliance purposes. The batching rule
# always returns true because vmapped as_strided does not support accessing
# storage locations not indexable by the input tensor.
# See the note above for more information.
- func: _has_same_storage_numel(Tensor self, Tensor other) -> bool
  variants: function
  dispatch:
    CompositeExplicitAutograd: _has_same_storage_numel
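
# A rough Python-level sketch of the comparison described above, for
# illustration only (on the Python side, `itemsize()` in the C++ formula
# corresponds to `element_size()`, and the storage is reached via
# `untyped_storage()`):
#
#   same = (self.untyped_storage().nbytes() // self.element_size()
#           == other.untyped_storage().nbytes() // other.element_size())
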
- func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
  variants: method
  tags: inplace_view

- func: rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)
  variants: method

- func: align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)
  variants: method

- func: align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)
  variants: method

- func: align_as(Tensor self, Tensor other) -> Tensor
  variants: method

- func: align_tensors(Tensor[] tensors) -> Tensor[]

# Not assert because it's a keyword; not Assert because FX already
# took that syntax
# TODO: need to specify this is side-effectful somehow
- func: _assert_async(Tensor self) -> ()
  dispatch:
    CPU: _assert_async_cpu
    CUDA: _assert_async_cuda

- func: _assert_async.msg(Tensor self, str assert_msg) -> ()
  dispatch:
    CPU: _assert_async_msg_cpu
    CUDA: _assert_async_msg_cuda

- func: _assert_scalar(Scalar self, str assert_msg) -> ()
  dispatch:
    CompositeExplicitAutograd: _assert_scalar

- func: _functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor
  dispatch:
    CompositeExplicitAutograd: _functional_assert_scalar

- func: _functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor
  dispatch:
    CPU: _functional_assert_async_msg_cpu

- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None) -> ()

- func: _print(str s) -> ()
  dispatch:
    CompositeExplicitAutograd: _print

- func: sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()
  dispatch:
    CompositeExplicitAutograd: sym_constrain_range

- func: sym_constrain_range_for_size(Scalar size, *, int? min=None, int? max=None) -> ()
  dispatch:
    CompositeExplicitAutograd: sym_constrain_range_for_size

- func: _functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
  dispatch:
    CompositeExplicitAutograd: _functional_sym_constrain_range

- func: _functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
  dispatch:
    CompositeExplicitAutograd: _functional_sym_constrain_range_for_size

- func: _make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  dispatch:
    CPU: _make_dep_token_cpu

- func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
  variants: method

- func: _use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool
  device_check: NoCheck # Tensor arguments allowed to be on different devices, see also _cudnn_ctc_loss
  dispatch:
    CUDA: _use_cudnn_ctc_loss

- func: _use_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> bool
  device_check: NoCheck # Tensor arguments allowed to be on different devices, see also _cudnn_ctc_loss
  dispatch:
    CUDA: _use_cudnn_ctc_loss_tensor

- func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
  device_check: NoCheck # log_probs is expected to be on CUDA while targets is expected to be on CPU
  dispatch:
    CUDA: _cudnn_ctc_loss
  autogen: _cudnn_ctc_loss.out

- func: _cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
  device_check: NoCheck # log_probs is expected to be on CUDA while targets is expected to be on CPU
  dispatch:
    CUDA: _cudnn_ctc_loss_tensor

- func: _use_cudnn_rnn_flatten_weight() -> bool

- func: _cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor
  dispatch:
    CUDA: _cudnn_rnn_flatten_weight
  autogen: _cudnn_rnn_flatten_weight.out

- func: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
  # rnn_tanh may or may not redispatch to _cudnn_rnn based on algorithm and build. Thus it might hit either the dispatch-time or the kernel-level device check.
  # Disable the dispatch-time device check for consistent behavior.
  device_check: NoCheck
  dispatch:
    CUDA: _cudnn_rnn
  autogen: _cudnn_rnn.out
  tags: nondeterministic_seeded

- func: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
  dispatch:
    CUDA: _cudnn_rnn_backward
  autogen: _cudnn_rnn_backward.out

- func: _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
  dispatch:
    CUDA: _cudnn_init_dropout_state
  autogen: _cudnn_init_dropout_state.out
  tags: nondeterministic_seeded

- func: _debug_has_internal_overlap(Tensor self) -> int
  variants: function

- func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
  variants: function
  dispatch:
    CUDA: fused_dropout_cuda
  tags: nondeterministic_seeded
  autogen: _fused_dropout.out

- func: _masked_scale(Tensor self, Tensor mask, float scale) -> Tensor
  variants: function
  dispatch:
    CUDA: masked_scale_cuda
  autogen: _masked_scale.out

- func: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
  variants: function
  dispatch:
    CPU: native_dropout_cpu
    CUDA: native_dropout_cuda
    NestedTensorCPU, NestedTensorCUDA: native_dropout_nested
  tags: [nondeterministic_seeded, core]
  autogen: native_dropout.out

- func: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor
  dispatch:
    CPU, NestedTensorCPU, NestedTensorCUDA: native_dropout_backward
    CUDA: native_dropout_backward_cuda
  autogen: native_dropout_backward.out
  tags: pointwise

- func: _sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor)

- func: _sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!)

- func: _sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!)

- func: _sobol_engine_initialize_state_(Tensor(a!) self, int dimension) -> Tensor(a!)

- func: _reshape_from_tensor(Tensor self, Tensor shape) -> Tensor

- func: _shape_as_tensor(Tensor self) -> Tensor

- func: dropout(Tensor input, float p, bool train) -> Tensor
  tags: nondeterministic_seeded

- func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
  tags: nondeterministic_seeded

- func: feature_dropout(Tensor input, float p, bool train) -> Tensor
  tags: nondeterministic_seeded

- func: feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
  tags: nondeterministic_seeded

- func: alpha_dropout(Tensor input, float p, bool train) -> Tensor
  tags: nondeterministic_seeded

- func: alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
  tags: nondeterministic_seeded

- func: feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor
  tags: nondeterministic_seeded

- func: feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
  tags: nondeterministic_seeded

- func: abs(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: abs
    SparseCPU, SparseCUDA: abs_sparse
    SparseCsrCPU, SparseCsrCUDA: abs_sparse_csr
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_abs
  tags: [core, pointwise]

- func: abs_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: abs_
    SparseCPU, SparseCUDA: abs_sparse_
    SparseCsrCPU, SparseCsrCUDA: abs_sparse_csr_
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_abs_

- func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: abs_out
    MPS: abs_out_mps
    SparseCPU, SparseCUDA: abs_sparse_out
    SparseCsrCPU, SparseCsrCUDA: abs_sparse_csr_out
  tags: pointwise

# Note [Adding an alias]
# To add an alias do the following:
#
# 1) Copy the original function's native_functions.yaml entry, but replace the
#    original function's name with the alias's name and delete any dispatch
#    keys for the alias. Specifying a dispatch key will prevent
#    autograd from recording the operations the alias performs, which
#    will stop it from "inheriting" the original operation's autograd behavior.
# 2) Implement the corresponding functions and have them redispatch to the
#    original function.
# 3) Add docstrings to the new function that reference the original function,
#    and document the method as usual (if it exists).
#    (See torch/_torch_docs.py and docs/source/torch.rst if adding a function,
#    torch/_tensor_docs.py and docs/source/tensors.rst if adding a method,
#    or module-specific doc bindings (like torch/linalg/__init__.py) if
#    adding an alias in a namespace.)
# 4) Update torch/overrides.py consistent with the original function.
# 5) Update the alias_map in torch/csrc/jit/passes/normalize_ops.cpp.
# 6) Add an aliases argument to the existing OpInfo/UnaryUfuncInfo, or create a
#    new OpInfo/UnaryUfuncInfo entry, in the op_db list in
#    torch/testing/_internal/common_methods_invocations.py.
#
# See torch.absolute, an alias for torch.abs, as an example.
# Absolute, alias for abs

- func: absolute(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method

- func: absolute_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method

- func: absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator

- func: angle(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CPU, CUDA: angle
    SparseCsrCPU, SparseCsrCUDA: angle_sparse_csr
  tags: pointwise

- func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: angle_out
    SparseCsrCPU, SparseCsrCUDA: angle_sparse_csr_out
  tags: pointwise

- func: view_as_real(Tensor(a) self) -> Tensor(a)
  variants: function
  dispatch:
    CPU, CUDA, MPS, Meta: view_as_real

- func: view_as_complex(Tensor(a) self) -> Tensor(a)
  variants: function
  dispatch:
    CPU, CUDA, MPS, Meta: view_as_complex

- func: sgn(Tensor self) -> Tensor
  variants: function, method
  structured_delegate: sgn.out
  dispatch:
    SparseCPU, SparseCUDA: sgn_sparse
    SparseCsrCPU, SparseCsrCUDA: sgn_sparse_csr
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_sgn
  tags: pointwise

- func: sgn_(Tensor(a!) self) -> Tensor(a!)
  variants: method
  structured_delegate: sgn.out
  dispatch:
    SparseCPU, SparseCUDA: sgn_sparse_
    SparseCsrCPU, SparseCsrCUDA: sgn_sparse_csr_
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_sgn_
  tags: pointwise

- func: sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: sgn_out
    MPS: sgn_out_mps
    SparseCPU, SparseCUDA: sgn_sparse_out
    SparseCsrCPU, SparseCsrCUDA: sgn_sparse_csr_out
  tags: pointwise

- func: chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
  variants: method

- func: real(Tensor(a) self) -> Tensor(a)
  device_check: NoCheck # TensorIterator
  variants: function

- func: imag(Tensor(a) self) -> Tensor(a)
  device_check: NoCheck # TensorIterator
  variants: function

- func: _conj(Tensor(a) self) -> Tensor(a)
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: _conj

- func: conj(Tensor(a) self) -> Tensor(a)
  variants: function, method
  manual_cpp_binding: True

- func: _conj_physical(Tensor self) -> Tensor
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: _conj_physical
    SparseCsrCPU, SparseCsrCUDA: conj_physical_sparse_csr
  autogen: _conj_physical.out

- func: conj_physical(Tensor self) -> Tensor
  variants: function, method
  tags: pointwise

- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: conj_physical_out
    MPS: conj_physical_out_mps
    SparseCPU, SparseCUDA: conj_physical_out_sparse
    SparseCsrCPU, SparseCsrCUDA: conj_physical_sparse_csr_out
  tags: pointwise

- func: conj_physical_(Tensor(a!) self) -> Tensor(a!)
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: conj_physical_
    SparseCsrCPU, SparseCsrCUDA: conj_physical_sparse_csr_
  tags: pointwise

- func: resolve_conj(Tensor(a) self) -> Tensor(a)
  variants: function, method

- func: resolve_neg(Tensor(a) self) -> Tensor(a)
  variants: function, method

- func: _neg_view(Tensor(a) self) -> Tensor(a)
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: _neg_view

- func: acos(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: acos.out
  tags: [core, pointwise]

- func: acos_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: acos.out
  tags: pointwise

- func: acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: acos_out
    MPS: acos_out_mps
  tags: pointwise

# arccos, alias of acos
- func: arccos(Tensor self) -> Tensor
  variants: function, method

- func: arccos_(Tensor(a!) self) -> Tensor(a!)
  variants: function, method

- func: arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

- func: avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor
  tags: core

- func: adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor
  tags: core

# Return: (Tensor output, Tensor indices)
- func: adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)

- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: add.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA, SparseMeta: add_sparse
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr
    MkldnnCPU: mkldnn_add
    ZeroTensor: add_zerotensor
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_add_Tensor
  tags: [core, pointwise]

- func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  structured_delegate: add.out
  dispatch:
    SparseCPU, SparseCUDA, SparseMeta: add_sparse_
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr_
    MkldnnCPU: mkldnn_add_
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_add__Tensor
  tags: pointwise

- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  ufunc_inner_loop:
    Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf)
    ScalarOnly: add (Bool)
  dispatch:
    SparseCPU, SparseMeta: add_out_sparse_cpu
    SparseCUDA: add_out_sparse_cuda
    SparseCsrCPU, SparseCsrMeta: add_out_sparse_compressed_cpu
    SparseCsrCUDA: add_out_sparse_compressed_cuda
    MkldnnCPU: mkldnn_add_out
    MPS: add_out_mps
  tags: pointwise

- func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  variants: function
  dispatch:
    CPU: add_relu

- func: _add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
  variants: function
  dispatch:
    CPU: add_relu_

- func: _add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  variants: function
  dispatch:
    CPU: add_relu_out

- func: _add_relu.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
  variants: function
  dispatch:
    CPU: add_relu

- func: _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
  variants: function
  dispatch:
    CPU: add_relu_
  autogen: _add_relu.Scalar_out

# For C++ only, until we have conversion from C++ numbers to Tensor
- func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: add
  tags: [core, pointwise]

- func: add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  dispatch:
    CompositeExplicitAutograd: add_
  autogen: add.Scalar_out
  tags: pointwise

- func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor
  structured_delegate: addmv.out
  variants: function, method

- func: addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
  structured_delegate: addmv.out
  variants: function, method

- func: addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:
    CPU: addmv_out_cpu
    CUDA: addmv_out_cuda
    MPS: addmv_out_mps
    SparseCsrCPU: addmv_out_sparse_compressed
    SparseCsrCUDA: addmv_out_sparse_compressed_cuda

- func: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
  variants: function, method
  dispatch:
    CPU, CUDA: addr
    MPS: addr_mps
    CompositeExplicitAutograd: math_addr

- func: addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
  variants: method
  dispatch:
    CompositeExplicitAutograd: addr_

- func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: addr_out
    MPS: addr_out_mps
    CompositeExplicitAutograd: math_addr_out

- func: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutograd: affine_grid_generator
  autogen: affine_grid_generator.out

- func: affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor
  variants: function

- func: _is_all_true(Tensor self) -> Tensor
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: _is_all_true

- func: _is_any_true(Tensor self) -> Tensor
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: _is_any_true

# Note: this function is only for testing.
- func: _test_check_tensor(Tensor self) -> Tensor
  variants: function

# Note: this function is only for testing.
- func: _test_functorch_fallback(Tensor self, Tensor other) -> Tensor
  variants: function
  dispatch:
    CPU: _test_functorch_fallback
  autogen: _test_functorch_fallback.out

- func: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: all.out
  variants: function, method

- func: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: all.dims_out
  variants: function, method
  cpp_no_default_args: ['dim']
  dispatch:
    CompositeExplicitAutograd: all_dims_default

- func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  dispatch:
    CPU, CUDA: all_out
    MPS: all_out_mps

- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  dispatch:
    CPU, CUDA: all_dims_out
    CompositeExplicitAutograd: all_dims_out_default
  cpp_no_default_args: ['dim']

- func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method

- func: all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator

- func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool
  variants: function, method
  tags: data_dependent_output
  dispatch:
    CompositeExplicitAutograd: allclose

- func: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: any.out
  variants: function, method
  tags: core

- func: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: any.dims_out
  variants: function, method
  cpp_no_default_args: ['dim']
  tags: core
  dispatch:
    CompositeExplicitAutograd: any_dims_default

- func: any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  dispatch:
    CPU, CUDA: any_out
    MPS: any_out_mps

- func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  dispatch:
    CPU, CUDA: any_dims_out
    CompositeExplicitAutograd: any_dims_out_default
  cpp_no_default_args: ['dim']

- func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method

- func: any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator

- func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CompositeExplicitAutograd: arange

- func: arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CompositeExplicitAutograd: arange

# This operator should be named `arange.start_out` if following the naming convention.
# However, that name is already taken. Disabled because of CI job failures.
# FIXME: enable this
#- func: arange.start_out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!)
#  dispatch:
#    CompositeExplicitAutograd: arange_start_out

- func: arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CompositeExplicitAutograd: arange
  cpp_no_default_args: ['step']
  tags: core

- func: arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: arange_out

- func: arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, Meta: arange_out
    CUDA: arange_cuda_out
    MPS: arange_mps_out
  cpp_no_default_args: ['step']

# This function is a temporary hack to allow tracing of arange-like constructs with dynamic
# bounds on arange. Normal arange is not traceable because it does not take any tensor inputs;
# if the range you need is based on another tensor, calling this function directly will
# preserve tracing. Get rid of this when arange can directly take tensors for bounds
# (so that it can be traced directly).
- func: _dim_arange(Tensor like, int dim) -> Tensor
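
# For illustration, a rough Python-level sketch of what the comment above means
# (hedged example; it assumes the usual `torch._dim_arange` binding generated
# from the entry above):
#
#   # torch.arange(x.size(0)) bakes the bound in as a constant when traced;
#   # torch._dim_arange(x, 0) keeps the dependency on x's shape instead.
#   idx = torch._dim_arange(x, 0)  # ~ torch.arange(x.size(0)) on x's device
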
- func: argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
  structured_delegate: argmax.out
  device_check: NoCheck # TensorIterator
  variants: function, method
  tags: core

- func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:
    CPU, CUDA: argmax_out
    MPS: argmax_out_mps

- func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
  structured_delegate: argmin.out
  device_check: NoCheck # TensorIterator
  variants: function, method
  tags: core

- func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:
    CPU, CUDA: argmin_out
    MPS: argmin_out_mps

- func: acosh(Tensor self) -> Tensor
  variants: function, method
  structured_delegate: acosh.out
  tags: [core, pointwise]

- func: acosh_(Tensor(a!) self) -> Tensor(a!)
  variants: function, method
  structured_delegate: acosh.out
  tags: pointwise

- func: acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: acosh_out
    MPS: acosh_out_mps
  tags: pointwise

# arccosh, alias for acosh
- func: arccosh(Tensor self) -> Tensor
  variants: function, method

- func: arccosh_(Tensor(a!) self) -> Tensor(a!)
  variants: function, method

- func: arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

- func: asinh(Tensor self) -> Tensor
  variants: function, method
  structured_delegate: asinh.out
  dispatch:
    SparseCPU, SparseCUDA: asinh_sparse
    SparseCsrCPU, SparseCsrCUDA: asinh_sparse_csr
  tags: [core, pointwise]

- func: asinh_(Tensor(a!) self) -> Tensor(a!)
  variants: function, method
  structured_delegate: asinh.out
  dispatch:
    SparseCPU, SparseCUDA: asinh_sparse_
    SparseCsrCPU, SparseCsrCUDA: asinh_sparse_csr_
  tags: pointwise

- func: asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: asinh_out
    MPS: asinh_out_mps
    SparseCPU, SparseCUDA: asinh_sparse_out
    SparseCsrCPU, SparseCsrCUDA: asinh_sparse_csr_out
  tags: pointwise

# arcsinh, alias for asinh
- func: arcsinh(Tensor self) -> Tensor
  variants: function, method

- func: arcsinh_(Tensor(a!) self) -> Tensor(a!)
  variants: function, method

- func: arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

- func: atanh(Tensor self) -> Tensor
  structured_delegate: atanh.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: atanh_sparse
    SparseCsrCPU, SparseCsrCUDA: atanh_sparse_csr
  tags: [core, pointwise]

- func: atanh_(Tensor(a!) self) -> Tensor(a!)
  structured_delegate: atanh.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: atanh_sparse_
    SparseCsrCPU, SparseCsrCUDA: atanh_sparse_csr_
  tags: pointwise

- func: atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: atanh_out
    MPS: atanh_out_mps
    SparseCPU, SparseCUDA: atanh_sparse_out
    SparseCsrCPU, SparseCsrCUDA: atanh_sparse_csr_out
  tags: pointwise

# arctanh, alias for atanh
- func: arctanh(Tensor self) -> Tensor
  variants: function, method

- func: arctanh_(Tensor(a!) self) -> Tensor(a!)
  variants: function, method

- func: arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

- func: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
  variants: function, method
  dispatch:
    ZeroTensor, CPU, CUDA: as_strided_tensorimpl
    Meta: as_strided_tensorimpl_meta_symint
    MPS: as_strided_tensorimpl_mps
    QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl
  device_check: NoCheck
  device_guard: False
  tags: core

- func: as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)
  use_const_ref_for_mutable_tensors: True
  variants: function, method
  device_check: NoCheck
  device_guard: False
  tags: inplace_view
  dispatch:
    CompositeExplicitAutogradNonFunctional: as_strided__symint

- func: asin(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: asin.out
  dispatch:
    SparseCPU, SparseCUDA: asin_sparse
    SparseCsrCPU, SparseCsrCUDA: asin_sparse_csr
  tags: [core, pointwise]

- func: asin_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: asin.out
  dispatch:
    SparseCPU, SparseCUDA: asin_sparse_
    SparseCsrCPU, SparseCsrCUDA: asin_sparse_csr_
  tags: pointwise

- func: asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: asin_out
    MPS: asin_out_mps
    SparseCPU, SparseCUDA: asin_sparse_out
    SparseCsrCPU, SparseCsrCUDA: asin_sparse_csr_out
  tags: pointwise

# arcsin, alias of asin
- func: arcsin(Tensor self) -> Tensor
  variants: function, method

- func: arcsin_(Tensor(a!) self) -> Tensor(a!)
  variants: function, method

- func: arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

- func: atan(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: atan.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: atan_sparse
    SparseCsrCPU, SparseCsrCUDA: atan_sparse_csr
  tags: [core, pointwise]

- func: atan_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured_delegate: atan.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: atan_sparse_
    SparseCsrCPU, SparseCsrCUDA: atan_sparse_csr_
  tags: pointwise

- func: atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: atan_out
    MPS: atan_out_mps
    SparseCPU, SparseCUDA: atan_sparse_out
    SparseCsrCPU, SparseCsrCUDA: atan_sparse_csr_out
  tags: pointwise

# arctan, alias of atan
- func: arctan(Tensor self) -> Tensor
  variants: function, method

- func: arctan_(Tensor(a!) self) -> Tensor(a!)
  variants: function, method

- func: arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

- func: atleast_1d(Tensor self) -> Tensor
  variants: function

- func: atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]

- func: atleast_2d(Tensor self) -> Tensor
  variants: function

- func: atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]
  variants: function

- func: atleast_3d(Tensor self) -> Tensor
  variants: function

- func: atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]
  variants: function

- func: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
  variants: function, method
  structured_delegate: baddbmm.out

- func: baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
  variants: method
  structured_delegate: baddbmm.out

- func: baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  structured: True
  variants: function
  dispatch:
    CPU: baddbmm_out_cpu
    CUDA: baddbmm_out_cuda
    MPS: baddbmm_out_mps
    SparseCsrCUDA: baddbmm_out_sparse_csr_cuda

- func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CompositeExplicitAutograd: bartlett_window
  autogen: bartlett_window.out

- func: bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CompositeExplicitAutograd: bartlett_window
  autogen: bartlett_window.periodic_out

- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor

- func: quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor
  dispatch:
    QuantizedCPU: quantized_batch_norm
  autogen: quantized_batch_norm.out

- func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)

- func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)

# Sample bernoulli with values in `self` as probability.
- func: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: bernoulli
  tags: nondeterministic_seeded

- func: bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: function
  tags: nondeterministic_seeded
  dispatch:
    CPU, CUDA: bernoulli_out
    MPS: bernoulli_out_mps

- func: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  tags: nondeterministic_seeded
  dispatch:
    CPU, CUDA: bernoulli_
    MPS: bernoulli_mps_
  autogen: bernoulli.Tensor, bernoulli.Tensor_out

- func: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  tags: nondeterministic_seeded
  dispatch:
    CPU, CUDA: bernoulli_
    MPS: bernoulli_mps_
  autogen: bernoulli.float_out

# Note [bernoulli.p schema]
# We should probably just fix the overload ambiguity by appending a _functional to the C++ API name (BC breaking)
# This out-of-place version isn't used explicitly, but is needed by the jit.
# There is no default value on `p` here because it would introduce ambiguity
# with the `bernoulli(Tensor self, *, Generator? generator=None)` declaration.
- func: bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  tags: nondeterministic_seeded
  dispatch:
    CompositeExplicitAutogradNonFunctional: bernoulli

- func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor

- func: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
  device_check: NoCheck # TensorIterator
  python_module: nn
  variants: function
  dispatch:
    CPU: binary_cross_entropy_cpu
    CUDA: binary_cross_entropy_cuda
    MPS: binary_cross_entropy_mps

- func: binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  python_module: nn
  variants: function
  dispatch:
    CPU: binary_cross_entropy_out_cpu
    CUDA: binary_cross_entropy_out_cuda
    MPS: binary_cross_entropy_out_mps

- func: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
  python_module: nn
  variants: function
  dispatch:
    CPU: binary_cross_entropy_backward_cpu
    CUDA: binary_cross_entropy_backward_cuda
    MPS: binary_cross_entropy_backward_mps

- func: binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)
  python_module: nn
  variants: function
  dispatch:
    CPU: binary_cross_entropy_backward_out_cpu
    CUDA: binary_cross_entropy_backward_out_cuda
    MPS: binary_cross_entropy_backward_out_mps

- func: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function
  dispatch:
    CompositeExplicitAutograd: binary_cross_entropy_with_logits
  autogen: binary_cross_entropy_with_logits.out

- func: bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor
  variants: function, method
  dispatch:
    CPU: _bincount_cpu
    CUDA: _bincount_cuda
    MPS: _bincount_mps
  tags: dynamic_output_shape
  autogen: bincount.out

- func: bitwise_not(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: bitwise_not.out
  variants: function, method
  tags: [core, pointwise]

- func: bitwise_not_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured_delegate: bitwise_not.out
  variants: method
  tags: pointwise

- func: bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: bitwise_not_out
    MPS: bitwise_not_out_mps
  tags: pointwise

- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: copysign_out
  tags: pointwise

- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: copysign.out
  tags: pointwise

- func: copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  structured_delegate: copysign.out

- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: copysign
  tags: pointwise

- func: copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method
  dispatch:
    CompositeExplicitAutograd: copysign_

- func: copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: copysign_out
  tags: pointwise

- func: _lazy_clone(Tensor self) -> Tensor
  # Like clone, but the copy takes place lazily, only if either the
  # input or the output are written.
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: _lazy_clone

- func: logical_not(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: logical_not
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_logical_not
  tags: [core, pointwise]

- func: logical_not_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  dispatch:
    CompositeExplicitAutograd: logical_not_
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_logical_not_
  tags: pointwise

- func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: logical_not_out
    MPS: logical_not_out_mps
  tags: pointwise

- func: logical_xor(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: logical_xor
  tags: [core, pointwise]

- func: logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  dispatch:
    CompositeExplicitAutograd: logical_xor_
  tags: pointwise

- func: logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: logical_xor_out
    MPS: logical_xor_out_mps
  tags: pointwise

- func: logical_and(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: logical_and
  tags: [core, pointwise]

- func: logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  dispatch:
    CompositeExplicitAutograd: logical_and_
  tags: pointwise

- func: logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: logical_and_out
    MPS: logical_and_out_mps
  tags: pointwise

- func: logical_or(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: logical_or
  tags: [core, pointwise]

- func: logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  dispatch:
    CompositeExplicitAutograd: logical_or_
  tags: pointwise

- func: logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: logical_or_out
    MPS: logical_or_out_mps
  tags: pointwise

- func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CompositeExplicitAutograd: blackman_window
  autogen: blackman_window.out

- func: blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CompositeExplicitAutograd: blackman_window
  autogen: blackman_window.periodic_out

- func: bmm(Tensor self, Tensor mat2) -> Tensor
  structured_delegate: bmm.out
  variants: function, method
  dispatch:
    SparseCPU: bmm_sparse_cpu
    SparseCUDA: bmm_sparse_cuda
    NestedTensorCPU: bmm_nested
    NestedTensorCUDA: bmm_nested_cuda
  tags: core

- func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  variants: function
  dispatch:
    CPU: bmm_out_cpu
    CUDA: bmm_out_cuda
    MPS: bmm_out_mps
    SparseCPU: bmm_out_sparse_cpu
    SparseCUDA: bmm_out_sparse_cuda
    SparseCsrCUDA: bmm_out_sparse_csr_cuda

- func: broadcast_tensors(Tensor[] tensors) -> Tensor[]
  device_check: NoCheck
  device_guard: False

- func: broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)
  variants: function, method
  dispatch:
    CompositeImplicitAutograd: broadcast_to_symint

- func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
  variants: function
  dispatch:
    SparseCPU, SparseCUDA: sparse_broadcast_to

- func: cat(Tensor[] tensors, int dim=0) -> Tensor
  structured_delegate: cat.out
  dispatch:
    SparseCPU, SparseCUDA: cat_sparse
    QuantizedCPU: cat_quantized_cpu
    NestedTensorCPU, NestedTensorCUDA: cat_nested
  tags: core

- func: cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  precomputed:
  - dim -> int dim, int valid, bool all_contiguous, bool all_same_dtype, bool all_same_sizes_and_stride, MemoryFormat memory_format
  dispatch:
    CPU: cat_out_cpu
    CUDA: cat_out_cuda
    MPS: cat_out_mps
    QuantizedCPU: cat_out_quantized_cpu

- func: cat.names(Tensor[] tensors, Dimname dim) -> Tensor

- func: cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)

# alias for torch.cat
- func: concat(Tensor[] tensors, int dim=0) -> Tensor

- func: concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)

- func: concat.names(Tensor[] tensors, Dimname dim) -> Tensor

- func: concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)

# alias for torch.cat
- func: concatenate(Tensor[] tensors, int dim=0) -> Tensor

- func: concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)

- func: concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor

- func: concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)

- func: block_diag(Tensor[] tensors) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutograd: block_diag
  autogen: block_diag.out

- func: ceil(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: ceil.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: ceil_sparse
    SparseCsrCPU, SparseCsrCUDA: ceil_sparse_csr
  tags: [core, pointwise]

- func: ceil_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured_delegate: ceil.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: ceil_sparse_
    SparseCsrCPU, SparseCsrCUDA: ceil_sparse_csr_
  tags: pointwise

- func: ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: ceil_out
    MPS: ceil_out_mps
    SparseCPU, SparseCUDA: ceil_sparse_out
    SparseCsrCPU, SparseCsrCUDA: ceil_sparse_csr_out
  tags: pointwise

# alias for torch.linalg.multi_dot
- func: chain_matmul(Tensor[] matrices) -> Tensor
  variants: function

# alias for torch.linalg.multi_dot
- func: chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!)

- func: unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[]
  variants: function, method
  device_check: NoCheck
  device_guard: False

- func: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
  variants: function, method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeImplicitAutograd: chunk
    NestedTensorCPU, NestedTensorCUDA: chunk_nested_tensor

- func: tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[]
  variants: function, method
  dispatch:
    CompositeImplicitAutograd: tensor_split_sections_symint

- func: tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[]
  variants: function, method
  dispatch:
    CompositeImplicitAutograd: tensor_split_indices_symint

- func: tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]
  variants: function, method

- func: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  cpp_no_default_args: ['min']
  structured_delegate: clamp.out
  dispatch:
    QuantizedCPU: clamp_quantized_cpu
  tags: [core, pointwise]

- func: clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
  variants: function, method
  structured_delegate: clamp.Tensor_out
  tags: [core, pointwise]

- func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: function, method
  cpp_no_default_args: ['min']
  structured_delegate: clamp.out
  tags: pointwise

- func: clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)
  variants: function, method
  structured_delegate: clamp.Tensor_out
  tags: pointwise

- func: clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  cpp_no_default_args: ['min']
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: clamp_out
    MPS: clamp_out_mps
  tags: pointwise

- func: clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: clamp_Tensor_out
    MPS: clamp_Tensor_out_mps
  tags: pointwise

- func: clamp_max(Tensor self, Scalar max) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: clamp_max.out
  tags: pointwise

- func: clamp_max.Tensor(Tensor self, Tensor max) -> Tensor
  variants: function, method
  structured_delegate: clamp_max.Tensor_out
  tags: pointwise

- func: clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: clamp_max.out
  tags: pointwise

- func: clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!)
  variants: function, method
  structured_delegate: clamp_max.Tensor_out
  tags: pointwise

- func: clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: clamp_max_out
    MPS: clamp_max_out_mps
  tags: pointwise

- func: clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: clamp_max_Tensor_out
    MPS: clamp_max_Tensor_out_mps
  tags: pointwise

- func: clamp_min(Tensor self, Scalar min) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: clamp_min.out
  tags: pointwise

- func: clamp_min.Tensor(Tensor self, Tensor min) -> Tensor
  variants: function, method
  structured_delegate: clamp_min.Tensor_out
  tags: pointwise

- func: clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: clamp_min.out
  tags: pointwise

- func: clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!)
  variants: function, method
  structured_delegate: clamp_min.Tensor_out
  tags: pointwise

- func: clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: clamp_min_out
    MPS: clamp_min_out_mps
  tags: pointwise

- func: clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: clamp_min_Tensor_out
    MPS: clamp_min_Tensor_out_mps
  tags: pointwise

# clip is an alias for clamp
- func: clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
  cpp_no_default_args: ['min']
  variants: function, method
  tags: pointwise

- func: clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
  variants: function, method
  tags: pointwise

- func: clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
  cpp_no_default_args: ['min']
  variants: function, method
  tags: pointwise

- func: clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)
1623 variants: function, method 1624 tags: pointwise 1625 1626- func: clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) 1627 cpp_no_default_args: ['min'] 1628 tags: pointwise 1629 1630- func: clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) 1631 1632- func: cudnn_is_acceptable(Tensor self) -> bool 1633 device_check: NoCheck 1634 device_guard: False 1635 1636- func: complex(Tensor real, Tensor imag) -> Tensor 1637 variants: function 1638 dispatch: 1639 CompositeExplicitAutograd: complex 1640 1641- func: complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!) 1642 dispatch: 1643 CPU, CUDA: complex_out 1644 MPS: complex_out_mps 1645 1646- func: polar(Tensor abs, Tensor angle) -> Tensor 1647 variants: function 1648 dispatch: 1649 CompositeExplicitAutograd: polar 1650 1651- func: polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!) 1652 dispatch: 1653 CPU, CUDA: polar_out 1654 MPS: polar_out_mps 1655 1656- func: constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor 1657 variants: function 1658 dispatch: 1659 CompositeExplicitAutograd: constant_pad_nd 1660 MPS: constant_pad_nd_mps 1661 autogen: constant_pad_nd.out 1662 tags: core 1663 1664- func: contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a) 1665 variants: method 1666 manual_cpp_binding: True 1667 1668- func: convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor 1669 dispatch: 1670 CompositeExplicitAutograd: convolution 1671 autogen: convolution.out 1672 tags: core 1673 1674- func: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) 1675 dispatch: 1676 CompositeExplicitAutograd, CUDA: convolution_backward 1677 autogen: convolution_backward.out 1678 tags: core 1679 1680- func: convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor 1681 dispatch: 1682 CompositeExplicitAutograd: convolution_overrideable 1683 autogen: convolution_overrideable.out 1684 1685- func: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) 1686 dispatch: 1687 CompositeExplicitAutograd: convolution_backward_overrideable 1688 autogen: convolution_backward_overrideable.out 1689 1690- func: _convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor 1691 dispatch: 1692 CompositeExplicitAutograd: _convolution 1693 autogen: _convolution.out 1694 1695- func: _convolution.deprecated(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor 1696 1697- func: _convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor 1698 dispatch: 1699 CompositeImplicitAutograd: _convolution_mode_symint 1700 1701- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) 1702 1703- func: conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor 1704 dispatch: 1705 CompositeImplicitAutograd: conv1d_symint 1706 1707- func: conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor 1708 dispatch: 1709 CompositeImplicitAutograd: conv2d_symint 1710 1711- func: conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor 1712 dispatch: 1713 CompositeImplicitAutograd: conv3d_symint 1714 1715- func: conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor 1716 cpp_no_default_args: ['bias', 'stride', 'padding'] 1717 dispatch: 1718 CompositeImplicitAutograd: conv1d_padding_symint 1719 1720- func: conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor 1721 cpp_no_default_args: ['bias', 'stride', 'padding'] 1722 dispatch: 1723 CompositeImplicitAutograd: conv2d_padding_symint 1724 1725- func: conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor 1726 cpp_no_default_args: ['bias', 'stride', 'padding'] 1727 dispatch: 1728 CompositeImplicitAutograd: conv3d_padding_symint 1729 1730- func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor 1731 dispatch: 1732 CompositeExplicitAutograd: conv_tbc 1733 autogen: conv_tbc.out 1734 1735- func: conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor) 1736 1737# NB: we inherit the goofy argument order from PyTorch torch.nn.functional 1738- func: conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor 1739 dispatch: 1740 CompositeImplicitAutograd: conv_transpose1d_symint 1741 1742- func: conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor 1743 dispatch: 1744 CompositeImplicitAutograd: conv_transpose2d_symint 1745 1746- func: conv_transpose3d.input(Tensor input, Tensor weight, Tensor? 
bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor 1747 dispatch: 1748 CompositeImplicitAutograd: conv_transpose3d_symint 1749 1750- func: copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor 1751 variants: function 1752 dispatch: 1753 Meta: copy_meta 1754 CompositeExplicitAutogradNonFunctional: copy 1755 tags: core 1756 1757- func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) 1758 variants: method 1759 device_check: NoCheck 1760 device_guard: False 1761 dispatch: 1762 MkldnnCPU: copy_mkldnn_ 1763 SparseCPU, SparseCUDA: copy_sparse_wrapper_ 1764 CompositeExplicitAutograd: copy_ 1765 SparseCsrCPU, SparseCsrCUDA: copy_sparse_compressed_ 1766 NestedTensorCPU, NestedTensorCUDA: copy_nested_ 1767 autogen: copy.out 1768 1769- func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor 1770 dispatch: 1771 MPS: _copy_from_mps 1772 autogen: _copy_from.out 1773 1774# We need this to be able to properly copy from a CPU to an XLA tensor with different sizes. 1775# See https://github.com/pytorch/xla/issues/2881 1776- func: _copy_from_and_resize(Tensor self, Tensor dst) -> Tensor 1777 dispatch: 1778 MPS: _copy_from_and_resize_mps 1779 autogen: _copy_from_and_resize.out 1780 1781- func: cos(Tensor self) -> Tensor 1782 device_check: NoCheck # TensorIterator 1783 variants: function, method 1784 structured_delegate: cos.out 1785 dispatch: 1786 NestedTensorCPU, NestedTensorCUDA: cos_nested 1787 tags: [core, pointwise] 1788 1789- func: cos_(Tensor(a!) self) -> Tensor(a!) 1790 device_check: NoCheck # TensorIterator 1791 variants: function, method 1792 structured_delegate: cos.out 1793 tags: pointwise 1794 1795- func: cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 1796 device_check: NoCheck # TensorIterator 1797 structured: True 1798 structured_inherits: TensorIteratorBase 1799 dispatch: 1800 CPU, CUDA: cos_out 1801 MPS: cos_out_mps 1802 tags: pointwise 1803 1804- func: cosh(Tensor self) -> Tensor 1805 device_check: NoCheck # TensorIterator 1806 variants: function, method 1807 structured_delegate: cosh.out 1808 tags: [core, pointwise] 1809 1810- func: cosh_(Tensor(a!) self) -> Tensor(a!) 1811 device_check: NoCheck # TensorIterator 1812 variants: function, method 1813 structured_delegate: cosh.out 1814 tags: pointwise 1815 1816- func: cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 1817 device_check: NoCheck # TensorIterator 1818 structured: True 1819 structured_inherits: TensorIteratorBase 1820 dispatch: 1821 CPU, CUDA: cosh_out 1822 MPS: cosh_out_mps 1823 tags: pointwise 1824 1825- func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor 1826 1827- func: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor 1828 variants: function, method 1829 dispatch: 1830 CPU: count_nonzero_cpu 1831 CUDA: count_nonzero_cuda 1832 MPS: count_nonzero_mps 1833 autogen: count_nonzero.dim_IntList_out 1834 1835- func: count_nonzero(Tensor self, int? dim=None) -> Tensor 1836 variants: function, method 1837 dispatch: 1838 CompositeExplicitAutograd: count_nonzero 1839 autogen: count_nonzero.out 1840 1841- func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? 
aweights=None) -> Tensor 1842 variants: function, method 1843 1844- func: corrcoef(Tensor self) -> Tensor 1845 variants: function, method 1846 1847- func: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid 1848 dispatch: 1849 CUDA: cudnn_affine_grid_generator_forward 1850 autogen: cudnn_affine_grid_generator.out 1851 1852# TODO: Why do I have to call this grad?! 1853- func: cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta 1854 dispatch: 1855 CUDA: cudnn_affine_grid_generator_backward 1856 autogen: cudnn_affine_grid_generator_backward.out 1857 1858- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) 1859 dispatch: 1860 CUDA: cudnn_batch_norm 1861 autogen: cudnn_batch_norm.out 1862 1863# NB: You can only use this if you used cudnn_batch_norm training=True 1864- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) 1865 dispatch: 1866 CUDA: cudnn_batch_norm_backward 1867 autogen: cudnn_batch_norm_backward.out 1868 1869- func: cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor 1870 dispatch: 1871 CUDA: cudnn_convolution 1872 1873- func: cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) 1874 dispatch: 1875 CUDA: cudnn_convolution_out 1876 1877- func: cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor 1878 dispatch: 1879 CUDA: cudnn_convolution_transpose 1880 autogen: cudnn_convolution_transpose.out 1881 1882- func: _mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor 1883 dispatch: 1884 MPS: _mps_convolution_transpose 1885 autogen: _mps_convolution_transpose.out 1886 1887- func: mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor) 1888 dispatch: 1889 MPS: mps_convolution_transpose_backward 1890 autogen: mps_convolution_transpose_backward.out 1891 1892- func: cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor 1893 dispatch: 1894 CUDA: cudnn_convolution_relu 1895 autogen: cudnn_convolution_relu.out 1896 1897- func: cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor 1898 dispatch: 1899 CUDA: cudnn_convolution_add_relu 1900 autogen: cudnn_convolution_add_relu.out 1901 1902# NB: input is special cased in a way I don't quite understand 1903- func: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output 1904 dispatch: 1905 CUDA: cudnn_grid_sampler_forward 1906 autogen: cudnn_grid_sampler.out 1907 1908- func: cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid) 1909 dispatch: 1910 CUDA: cudnn_grid_sampler_backward 1911 autogen: cudnn_grid_sampler_backward.out 1912 1913- func: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) 1914 device_check: NoCheck # TensorIterator 1915 variants: function, method 1916 dispatch: 1917 CompositeExplicitAutograd: cummax 1918 1919- func: cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) 1920 device_check: NoCheck # TensorIterator 1921 dispatch: 1922 CompositeExplicitAutograd: cummax_out 1923 1924- func: cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) 1925 device_check: NoCheck # TensorIterator 1926 variants: function, method 1927 1928- func: cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) 1929 device_check: NoCheck # TensorIterator 1930 1931- func: _cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> () 1932 variants: function 1933 dispatch: 1934 CPU: cummax_helper_cpu 1935 CUDA: cummax_helper_cuda 1936 1937- func: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) 1938 device_check: NoCheck # TensorIterator 1939 variants: function, method 1940 dispatch: 1941 CompositeExplicitAutograd: cummin 1942 1943- func: cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) 1944 device_check: NoCheck # TensorIterator 1945 dispatch: 1946 CompositeExplicitAutograd: cummin_out 1947 1948- func: cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) 1949 device_check: NoCheck # TensorIterator 1950 variants: function, method 1951 1952- func: cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) 1953 device_check: NoCheck # TensorIterator 1954 1955- func: _cummin_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> () 1956 variants: function 1957 dispatch: 1958 CPU: cummin_helper_cpu 1959 CUDA: cummin_helper_cuda 1960 1961- func: cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor 1962 variants: function 1963 device_check: NoCheck 1964 device_guard: False 1965 1966- func: cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor 1967 structured_delegate: cumprod.out 1968 device_check: NoCheck # TensorIterator 1969 variants: function, method 1970 1971- func: cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) 1972 structured_delegate: cumprod.out 1973 variants: method 1974 1975- func: cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 1976 structured: True 1977 device_check: NoCheck # TensorIterator 1978 dispatch: 1979 CPU, CUDA: cumprod_out 1980 MPS: cumprod_out_mps 1981 1982- func: cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? 
dtype=None) -> Tensor 1983 device_check: NoCheck # TensorIterator 1984 variants: function, method 1985 1986- func: cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) 1987 variants: method 1988 1989- func: cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 1990 device_check: NoCheck # TensorIterator 1991 1992- func: cumprod_backward(Tensor grad, Tensor input, int dim, Tensor output) -> Tensor 1993 variants: function 1994 device_check: NoCheck 1995 device_guard: False 1996 1997- func: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor 1998 structured_delegate: cumsum.out 1999 device_check: NoCheck # TensorIterator 2000 variants: function, method 2001 tags: core 2002 2003- func: cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) 2004 structured_delegate: cumsum.out 2005 variants: method 2006 2007- func: cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 2008 structured: True 2009 device_check: NoCheck # TensorIterator 2010 dispatch: 2011 CPU, CUDA: cumsum_out 2012 MPS: cumsum_out_mps 2013 2014- func: cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor 2015 device_check: NoCheck # TensorIterator 2016 variants: function, method 2017 2018- func: cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) 2019 variants: method 2020 2021- func: cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 2022 device_check: NoCheck # TensorIterator 2023 2024- func: cumulative_trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor 2025 2026- func: cumulative_trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor 2027 2028- func: ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor 2029 2030# convenience function that converts to intlists for you 2031- func: ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor 2032 2033- func: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) 2034 dispatch: 2035 CPU: ctc_loss_cpu 2036 CUDA: ctc_loss_gpu 2037 Meta: ctc_loss_meta 2038 autogen: _ctc_loss.out 2039 tags: dynamic_output_shape # the shape of second output is data dependent 2040 2041- func: _ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) 2042 dispatch: 2043 CPU, CUDA: ctc_loss_tensor 2044 autogen: _ctc_loss.Tensor_out 2045 tags: dynamic_output_shape # the shape of second output is data dependent 2046 2047- func: _ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor 2048 dispatch: 2049 CPU: ctc_loss_backward_cpu 2050 CUDA: ctc_loss_backward_gpu 2051 autogen: _ctc_loss_backward.out 2052 2053- func: _ctc_loss_backward.Tensor(Tensor grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor 2054 dispatch: 2055 CPU, CUDA: ctc_loss_backward_tensor 2056 2057- func: 
diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor
  variants: function, method
  dispatch:
    CompositeExplicitAutogradNonFunctional: diag_embed
  autogen: diag_embed.out

- func: diagflat(Tensor self, int offset=0) -> Tensor
  variants: function, method

- func: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: diagonal
  tags: core

- func: linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a)
  python_module: linalg
  variants: function

- func: diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a)
  variants: function, method

- func: diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor
  variants: function
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeExplicitAutograd: diagonal_backward_symint
  autogen: diagonal_backward.out

- func: fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)
  variants: method

- func: diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor
  variants: function, method

- func: diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!)
  variants: function

- func: gradient.scalarint(Tensor self, *, Scalar? spacing=None, int? dim=None, int edge_order=1) -> Tensor[]
  variants: function

- func: gradient.scalararray(Tensor self, *, Scalar spacing, int[] dim, int edge_order=1) -> Tensor[]
  variants: function

- func: gradient.array(Tensor self, *, int[] dim, int edge_order=1) -> Tensor[]
  variants: function

- func: gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[]
  variants: function

- func: gradient.scalarrayarray(Tensor self, *, Scalar[] spacing, int[] dim, int edge_order=1) -> Tensor[]
  variants: function

- func: gradient.tensorarrayint(Tensor self, *, Tensor[] spacing, int? dim=None, int edge_order=1) -> Tensor[]
  variants: function

- func: gradient.tensorarray(Tensor self, *, Tensor[] spacing, int[] dim, int edge_order=1) -> Tensor[]
  variants: function

- func: div.Tensor(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: div.out
  dispatch:
    SparseCPU, SparseCUDA: div_sparse
    ZeroTensor: div_zerotensor
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Tensor
  tags: [core, pointwise]

- func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  structured_delegate: div.out
  dispatch:
    SparseCPU, SparseCUDA: div_sparse_
  tags: pointwise

- func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: div_out
    MPS: div_out_mps
    SparseCPU, SparseCUDA: div_out_sparse_zerodim
  tags: pointwise

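# A usage sketch of the `rounding_mode` overloads declared below (illustration
# only, not part of the schema): rounding_mode=None performs true division,
# "trunc" rounds toward zero, and "floor" matches Python's // operator.
#
#   import torch
#   a = torch.tensor([ 7., -7.])
#   b = torch.tensor([ 2.,  2.])
#   torch.div(a, b)                         # tensor([ 3.5000, -3.5000])
#   torch.div(a, b, rounding_mode="trunc")  # tensor([ 3., -3.])
#   torch.div(a, b, rounding_mode="floor")  # tensor([ 3., -4.])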
- func: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: div.out_mode
  dispatch:
    SparseCPU, SparseCUDA: div_sparse
  tags: [core, pointwise]

- func: div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  structured_delegate: div.out_mode
  dispatch:
    SparseCPU, SparseCUDA: div_sparse_
  tags: pointwise

- func: div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: div_out_mode
    MPS: div_out_mode_mps
    SparseCPU, SparseCUDA: div_out_sparse_zerodim
  tags: pointwise

# For C++ only, until we have conversion from C++ numbers to Tensor
- func: div.Scalar(Tensor self, Scalar other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: div
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Scalar
  tags: [core, pointwise]

- func: div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  dispatch:
    CompositeExplicitAutograd: div_
  autogen: div.Scalar_out
  tags: pointwise

- func: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: div
  tags: [core, pointwise]

- func: div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)
  variants: method
  dispatch:
    CompositeExplicitAutograd: div_
  autogen: div.Scalar_mode_out
  tags: pointwise

# divide, alias for div
- func: divide.Tensor(Tensor self, Tensor other) -> Tensor
  variants: function, method

- func: divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method

- func: divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

- func: divide.Scalar(Tensor self, Scalar other) -> Tensor
  variants: function, method

- func: divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method

- func: divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
  variants: function, method

- func: divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)
  variants: method

- func: divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)

- func: divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
  variants: function, method

- func: divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)
  variants: method

# true_divide, an alias for div
- func: true_divide.Tensor(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  tags: pointwise

- func: true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method

- func: true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator

- func: true_divide.Scalar(Tensor self, Scalar other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method

- func: true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method

- func: dot(Tensor self, Tensor tensor) -> Tensor
  variants: function, method
  dispatch:
    CPU: dot
    CUDA: dot_cuda
    MPS: dot_mps

- func: dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: dot_out

- func: vdot(Tensor self, Tensor other) -> Tensor
  variants: function, method
  dispatch:
    CPU: vdot
    CUDA: vdot_cuda

- func: vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: vdot_out

- func: einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor

- func: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
  dispatch:
    CompositeExplicitAutograd: embedding_symint
    NestedTensorCPU, NestedTensorCUDA: NestedTensor_embedding
  autogen: embedding.out
  tags: core

- func: embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
  dispatch:
    CompositeImplicitAutograd: embedding_backward_symint

- func: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor
  dispatch:
    CPU: embedding_dense_backward_cpu
    CUDA: embedding_dense_backward_cuda
    MPS: embedding_dense_backward_mps
  autogen: embedding_dense_backward.out
  tags: core

- func: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)
  dispatch:
    CPU: embedding_renorm_cpu_
    CUDA: embedding_renorm_cuda_
  autogen: embedding_renorm, embedding_renorm.out

- func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor

# NOTE [ embedding_bag Native Functions ]
# The `_embedding_bag.*` variants assume that input tensors except for `weight`,
# e.g. `indices` and `offsets` (and `offset2bag`), are contiguous.
# We really only need to enforce this for `_embedding_bag` (the forward) because
# the backward inputs are the same as forward ones.
# The `embedding_bag` wrapper below is created to achieve this, e.g.,
# applying indices = indices.contiguous().
# The backward functions apply a check that these input tensors are contiguous.

- func: _embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
  dispatch:
    CPU: _embedding_bag_forward_only_cpu
    CUDA: _embedding_bag_forward_only_cuda
  autogen: _embedding_bag_forward_only.out

- func: _rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)

# row_stack is an alias for vstack
- func: row_stack(Tensor[] tensors) -> Tensor

- func: row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)

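# A usage sketch for the `embedding_bag` entries below (illustration only, not
# part of the schema). The public wrapper takes care of the contiguity
# requirement described in NOTE [ embedding_bag Native Functions ]:
#
#   import torch
#   import torch.nn.functional as F
#   weight = torch.randn(10, 3)                       # 10 embeddings of dim 3
#   indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # all bags, flattened
#   offsets = torch.tensor([0, 4])                    # bag i starts at offsets[i]
#   out = F.embedding_bag(indices, weight, offsets, mode="mean")  # shape (2, 3)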
- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)

# To keep backward and forward compatibility, and to avoid ambiguity with the
# original signature above, scale_grad_by_freq, mode, sparse,
# per_sample_weights, and include_last_offset parameters do not have default
# values. Once the original signature is removed, default values can be added.
- func: embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)

- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
  dispatch:
    CPU: _embedding_bag_cpu
    CUDA: _embedding_bag_cuda
  autogen: _embedding_bag.out
  tags: core

- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
  dispatch:
    CompositeImplicitAutograd: _embedding_bag_backward_symint

- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
  dispatch:
    CompositeImplicitAutograd: _embedding_bag_sparse_backward_symint

- func: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
  dispatch:
    CPU: _embedding_bag_dense_backward_cpu
    CUDA: _embedding_bag_dense_backward_cuda
  autogen: _embedding_bag_dense_backward.out

- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor
  dispatch:
    CPU: _embedding_bag_per_sample_weights_backward_cpu
    CUDA: _embedding_bag_per_sample_weights_backward_cuda
  autogen: _embedding_bag_per_sample_weights_backward.out

- func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeExplicitAutograd: empty_names
  autogen: empty.names_out

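# The factory overloads below all accept the usual TensorOptions-style keyword
# arguments (dtype, layout, device, pin_memory, and sometimes memory_format).
# A usage sketch (illustration only, not part of the schema):
#
#   import torch
#   a = torch.empty(2, 3, dtype=torch.float16)  # uninitialized values
#   b = torch.empty_like(a)                     # inherits a's dtype/device/layout
#   c = a.new_zeros((4, 4))                     # zeros with a's dtype/device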
- func: empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  dispatch:
    CPU: empty_cpu
    CUDA: empty_cuda
    MPS: empty_mps
    Meta: empty_meta_symint
    MkldnnCPU: empty_mkldnn
    SparseCPU, SparseCUDA, SparseMeta: empty_sparse
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: empty_sparse_compressed
    QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized
  tags: core

- func: empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CompositeExplicitAutograd: empty_permuted_symint
  autogen: empty_permuted.out

# We do not make new_empty a composite that calls into new_empty_strided, as the strided version
# is significantly more difficult to implement by different backends
- func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  variants: method
  dispatch:
    CompositeExplicitAutograd: new_empty_symint
  autogen: new_empty.out

- func: new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  variants: method
  dispatch:
    CompositeExplicitAutogradNonFunctional: new_empty_strided_symint
  autogen: new_empty_strided.out

- func: new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  variants: method
  dispatch:
    # NB: Although this composite mutates on the inside, it is
    # non-differentiable so NonFunctional doesn't apply
    CompositeExplicitAutograd: new_full
  autogen: new_full.out

- func: new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  variants: method
  dispatch:
    # NB: Although this composite mutates on the inside, it is
    # non-differentiable so NonFunctional doesn't apply
    CompositeExplicitAutograd: new_zeros
  autogen: new_zeros.out

- func: new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  variants: method
  dispatch:
    # NB: Although this composite mutates on the inside, it is
    # non-differentiable so NonFunctional doesn't apply
    CompositeExplicitAutograd: new_ones
  autogen: new_ones.out

# other overrides are to provide a more helpful error message that dtype is required
- func: _empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor
  dispatch:
    CPU: empty_affine_quantized_other_backends_stub
    QuantizedCPU, QuantizedCUDA: empty_affine_quantized
  autogen: _empty_affine_quantized.out

# it's a factory function receiving a tensor argument, thus overriding explicitly
# other overrides are to provide a more helpful error message that dtype is required
- func: _empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat?
memory_format=contiguous_format) -> Tensor 2431 category_override: factory 2432 dispatch: 2433 CPU: empty_per_channel_affine_quantized_other_backends_stub 2434 QuantizedCPU, QuantizedCUDA: empty_per_channel_affine_quantized 2435 autogen: _empty_per_channel_affine_quantized.out 2436 2437- func: resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) 2438 use_const_ref_for_mutable_tensors: True 2439 variants: method 2440 device_check: NoCheck 2441 device_guard: False 2442 tags: [core, inplace_view] 2443 dispatch: 2444 Meta: resize__symint 2445 CPU: resize_ 2446 CUDA: resize_cuda_ 2447 MPS: resize_mps_ 2448 QuantizedCPU: quantized_resize_cpu_ 2449 SparseCsrCPU, SparseCsrCUDA: resize_sparse_csr_ 2450 autogen: resize, resize.out 2451 2452# This is a utility function to enable users to resize out tensor while registering kernels for out variants. 2453# Eventually, we can consider exposing `resize_output` as a public API to ship it with python op registration 2454# to make it easy to register out variants for ops. 2455- func: _resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!) 2456 use_const_ref_for_mutable_tensors: True 2457 variants: function 2458 dispatch: 2459 Meta: _resize_output_ 2460 autogen: _resize_output, _resize_output.out 2461 2462- func: empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor 2463 category_override: factory 2464 variants: function 2465 dispatch: 2466 QuantizedCPU, QuantizedCUDA: empty_quantized 2467 autogen: empty_quantized.out 2468 2469- func: empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) 2470 device_check: NoCheck 2471 device_guard: False 2472 2473- func: empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor 2474 device_check: NoCheck 2475 device_guard: False 2476 dispatch: 2477 CompositeExplicitAutograd: empty_like 2478 QuantizedCPU, QuantizedCUDA: empty_like_quantized 2479 SparseCPU, SparseCUDA, SparseMeta: empty_like_sparse_coo 2480 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: empty_like_sparse_csr 2481 NestedTensorCPU, NestedTensorCUDA: empty_like_nested 2482 autogen: empty_like.out 2483 2484- func: empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2485 dispatch: 2486 CPU: empty_strided_cpu 2487 CUDA: empty_strided_cuda 2488 MPS: empty_strided_mps 2489 Meta: empty_strided_meta_symint 2490 QuantizedCPU, QuantizedCUDA: empty_strided_unknown_quantized 2491 autogen: empty_strided.out 2492 tags: core 2493 2494- func: erf(Tensor self) -> Tensor 2495 device_check: NoCheck # TensorIterator 2496 structured_delegate: erf.out 2497 variants: function, method 2498 dispatch: 2499 SparseCPU, SparseCUDA: erf_sparse 2500 SparseCsrCPU, SparseCsrCUDA: erf_sparse_csr 2501 tags: [core, pointwise] 2502 2503- func: erf_(Tensor(a!) self) -> Tensor(a!) 2504 device_check: NoCheck # TensorIterator 2505 structured_delegate: erf.out 2506 variants: function, method 2507 dispatch: 2508 SparseCPU, SparseCUDA: erf_sparse_ 2509 SparseCsrCPU, SparseCsrCUDA: erf_sparse_csr_ 2510 tags: pointwise 2511 2512- func: erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
2513 device_check: NoCheck # TensorIterator 2514 structured: True 2515 structured_inherits: TensorIteratorBase 2516 dispatch: 2517 CPU, CUDA: erf_out 2518 MPS: erf_out_mps 2519 SparseCPU, SparseCUDA: erf_sparse_out 2520 SparseCsrCPU, SparseCsrCUDA: erf_sparse_csr_out 2521 tags: pointwise 2522 2523- func: erfc(Tensor self) -> Tensor 2524 device_check: NoCheck # TensorIterator 2525 structured_delegate: erfc.out 2526 variants: function, method 2527 tags: pointwise 2528 2529- func: erfc_(Tensor(a!) self) -> Tensor(a!) 2530 device_check: NoCheck # TensorIterator 2531 structured_delegate: erfc.out 2532 variants: function, method 2533 tags: pointwise 2534 2535- func: erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 2536 device_check: NoCheck # TensorIterator 2537 structured: True 2538 structured_inherits: TensorIteratorBase 2539 dispatch: 2540 CPU, CUDA: erfc_out 2541 tags: pointwise 2542 2543- func: exp(Tensor self) -> Tensor 2544 device_check: NoCheck # TensorIterator 2545 structured_delegate: exp.out 2546 variants: function, method 2547 tags: [core, pointwise] 2548 2549- func: exp_(Tensor(a!) self) -> Tensor(a!) 2550 device_check: NoCheck # TensorIterator 2551 structured_delegate: exp.out 2552 variants: function, method 2553 tags: pointwise 2554 2555- func: exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 2556 device_check: NoCheck # TensorIterator 2557 structured: True 2558 structured_inherits: TensorIteratorBase 2559 dispatch: 2560 CPU, CUDA: exp_out 2561 MPS: exp_out_mps 2562 tags: pointwise 2563 2564- func: exp2(Tensor self) -> Tensor 2565 structured_delegate: exp2.out 2566 variants: function, method 2567 tags: pointwise 2568 2569- func: exp2_(Tensor(a!) self) -> Tensor(a!) 2570 structured_delegate: exp2.out 2571 variants: function, method 2572 tags: pointwise 2573 2574- func: exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 2575 structured: True 2576 structured_inherits: TensorIteratorBase 2577 dispatch: 2578 CPU, CUDA: exp2_out 2579 MPS: exp2_out_mps 2580 tags: pointwise 2581 2582- func: expm1(Tensor self) -> Tensor 2583 device_check: NoCheck # TensorIterator 2584 structured_delegate: expm1.out 2585 variants: function, method 2586 dispatch: 2587 SparseCPU, SparseCUDA: expm1_sparse 2588 SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr 2589 tags: [core, pointwise] 2590 2591- func: expm1_(Tensor(a!) self) -> Tensor(a!) 2592 device_check: NoCheck # TensorIterator 2593 structured_delegate: expm1.out 2594 variants: function, method 2595 dispatch: 2596 SparseCPU, SparseCUDA: expm1_sparse_ 2597 SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr_ 2598 tags: pointwise 2599 2600- func: expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 2601 device_check: NoCheck # TensorIterator 2602 structured: True 2603 structured_inherits: TensorIteratorBase 2604 dispatch: 2605 CPU, CUDA: expm1_out 2606 MPS: expm1_out_mps 2607 SparseCPU, SparseCUDA: expm1_sparse_out 2608 SparseCsrCPU, SparseCsrCUDA: expm1_sparse_csr_out 2609 tags: pointwise 2610 2611- func: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) 2612 variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. 2613 device_check: NoCheck 2614 device_guard: False 2615 dispatch: 2616 CompositeExplicitAutograd: expand 2617 tags: core 2618 2619- func: expand_as(Tensor(a) self, Tensor other) -> Tensor(a) 2620 variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. 
2621 device_check: NoCheck 2622 device_guard: False 2623 2624# decomposes to eye.m 2625- func: eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2626 dispatch: 2627 CompositeExplicitAutograd: eye 2628 2629- func: eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2630 dispatch: 2631 CompositeExplicitAutograd: eye 2632 2633- func: eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) 2634 dispatch: 2635 CPU, Meta: eye_out_cpu 2636 CUDA: eye_out_cuda 2637 MPS: eye_out_mps 2638 2639- func: eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) 2640 dispatch: 2641 CPU, Meta: eye_out_cpu 2642 CUDA: eye_out_cuda 2643 MPS: eye_out_mps 2644 2645- func: flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a) 2646 variants: function, method 2647 2648- func: flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a) 2649 variants: function, method 2650 2651- func: flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a) 2652 variants: function, method 2653 2654- func: flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a) 2655 variants: function, method 2656 2657- func: unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) 2658 variants: function, method 2659 dispatch: 2660 CompositeImplicitAutograd: unflatten_symint 2661 2662- func: unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) 2663 variants: function, method 2664 dispatch: 2665 CompositeImplicitAutograd: unflatten_dimname_symint 2666 2667- func: fill.Scalar(Tensor self, Scalar value) -> Tensor 2668 variants: function 2669 dispatch: 2670 CompositeExplicitAutograd: fill 2671 tags: core 2672 2673- func: fill.Tensor(Tensor self, Tensor value) -> Tensor 2674 variants: function 2675 dispatch: 2676 CompositeExplicitAutograd: fill 2677 2678- func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) 2679 device_check: NoCheck # TensorIterator 2680 variants: function, method 2681 dispatch: 2682 CPU, CUDA: fill_ 2683 MPS: fill_scalar_mps 2684 QuantizedCPU, QuantizedCUDA: fill_quantized_ 2685 Meta: fill_meta_ 2686 SparseCsrCPU, SparseCsrCUDA: fill_sparse_csr_ 2687 NestedTensorCPU, NestedTensorCUDA: fill_nested_ 2688 autogen: fill.Scalar_out 2689 2690- func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) 2691 device_check: NoCheck # TensorIterator 2692 variants: function, method 2693 dispatch: 2694 CPU, CUDA: fill_ 2695 MPS: fill_tensor_mps_ 2696 QuantizedCPU, QuantizedCUDA: fill_quantized_ 2697 Meta: fill_meta_ 2698 NestedTensorCPU, NestedTensorCUDA: fill_nested_ 2699 autogen: fill.Tensor_out 2700 2701- func: floor(Tensor self) -> Tensor 2702 device_check: NoCheck # TensorIterator 2703 structured_delegate: floor.out 2704 variants: function, method 2705 dispatch: 2706 SparseCPU, SparseCUDA: floor_sparse 2707 SparseCsrCPU, SparseCsrCUDA: floor_sparse_csr 2708 tags: [core, pointwise] 2709 2710- func: floor_(Tensor(a!) self) -> Tensor(a!) 2711 device_check: NoCheck # TensorIterator 2712 structured_delegate: floor.out 2713 variants: function, method 2714 dispatch: 2715 SparseCPU, SparseCUDA: floor_sparse_ 2716 SparseCsrCPU, SparseCsrCUDA: floor_sparse_csr_ 2717 tags: pointwise 2718 2719- func: floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
2720 device_check: NoCheck # TensorIterator 2721 structured: True 2722 structured_inherits: TensorIteratorBase 2723 dispatch: 2724 CPU, CUDA: floor_out 2725 MPS: floor_out_mps 2726 SparseCPU, SparseCUDA: floor_sparse_out 2727 SparseCsrCPU, SparseCsrCUDA: floor_sparse_csr_out 2728 tags: pointwise 2729 2730- func: floor_divide(Tensor self, Tensor other) -> Tensor 2731 device_check: NoCheck # TensorIterator 2732 variants: function, method 2733 dispatch: 2734 CPU, CUDA: floor_divide 2735 MPS: floor_divide_mps 2736 SparseCPU, SparseCUDA: floor_divide_sparse 2737 2738- func: floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 2739 device_check: NoCheck # TensorIterator 2740 variants: method 2741 dispatch: 2742 CPU, CUDA: floor_divide_ 2743 MPS: floor_divide_mps_ 2744 SparseCPU, SparseCUDA: floor_divide_sparse_ 2745 2746- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 2747 device_check: NoCheck # TensorIterator 2748 dispatch: 2749 CPU, CUDA: floor_divide_out 2750 MPS: floor_divide_out_mps 2751 SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim 2752 2753- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor 2754 device_check: NoCheck # TensorIterator 2755 variants: function, method 2756 dispatch: 2757 CompositeExplicitAutograd: floor_divide 2758 2759- func: floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 2760 device_check: NoCheck # TensorIterator 2761 variants: method 2762 dispatch: 2763 CompositeExplicitAutograd: floor_divide_ 2764 autogen: floor_divide.Scalar_out 2765 2766- func: frac(Tensor self) -> Tensor 2767 device_check: NoCheck # TensorIterator 2768 structured_delegate: frac.out 2769 variants: function, method 2770 dispatch: 2771 SparseCPU, SparseCUDA: frac_sparse 2772 SparseCsrCPU, SparseCsrCUDA: frac_sparse_csr 2773 tags: pointwise 2774 2775- func: frac_(Tensor(a!) self) -> Tensor(a!) 2776 device_check: NoCheck # TensorIterator 2777 structured_delegate: frac.out 2778 variants: function, method 2779 dispatch: 2780 SparseCPU, SparseCUDA: frac_sparse_ 2781 SparseCsrCPU, SparseCsrCUDA: frac_sparse_csr_ 2782 tags: pointwise 2783 2784- func: frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 2785 device_check: NoCheck # TensorIterator 2786 structured: True 2787 structured_inherits: TensorIteratorBase 2788 dispatch: 2789 CPU, CUDA: frac_out 2790 MPS: frac_out_mps 2791 SparseCPU, SparseCUDA: frac_sparse_out 2792 SparseCsrCPU, SparseCsrCUDA: frac_sparse_csr_out 2793 tags: pointwise 2794 2795- func: full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2796 device_check: NoCheck 2797 device_guard: False 2798 dispatch: 2799 CompositeExplicitAutograd: full 2800 autogen: full.names_out 2801 2802- func: full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2803 dispatch: 2804 CompositeExplicitAutograd: full 2805 tags: core 2806 2807- func: full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) 2808 dispatch: 2809 CompositeExplicitAutograd: full_out 2810 2811- func: full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor
  dispatch:
    # NB: Although this composite mutates on the inside, it is
    # non-differentiable so NonFunctional doesn't apply
    CompositeExplicitAutograd: full_like
  autogen: full_like.out

- func: from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CPU: from_file
  autogen: from_file.out

- func: gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: gcd_out
  tags: pointwise

- func: gcd(Tensor self, Tensor other) -> Tensor
  structured_delegate: gcd.out
  variants: function, method
  tags: pointwise

- func: gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: gcd.out
  variants: function, method

- func: lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: lcm_out
  tags: pointwise

- func: lcm(Tensor self, Tensor other) -> Tensor
  structured_delegate: lcm.out
  variants: function, method
  tags: pointwise

- func: lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: lcm.out
  variants: function, method

# NOTE [ grid_sampler Native Functions ]
# `grid_sampler` is _supposed to_ do all the shape checking and then dispatch to
# one of `cudnn_grid_sampler`, `grid_sampler_2d`, or `grid_sampler_3d`, each of
# which has the corresponding backward defined as native functions as well.
# However, we do shape checking everywhere for now since each of the mentioned
# functions can be called directly, which will lead to crashes otherwise.
# See https://github.com/pytorch/pytorch/issues/73187 for more information.
#
# There is also _grid_sampler_2d_backward_cpu_fallback which is an
# implementation detail of grid_sampler_2d and is only exposed here for testing
# purposes.
#
# Additionally, arguments `padding_mode` and `interpolation_mode` are cast to
# enums defined in `native/GridSampler.h`. `cudnn_grid_sampler` doesn't take in
# `interpolation_mode` because it only supports Bilinear interpolation mode.
# Nor does it take in `align_corners` because it only supports the mode
# `align_corners = True`.
- func: grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor

- func: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
  dispatch:
    CPU, QuantizedCPU: grid_sampler_2d_cpu
    CUDA: grid_sampler_2d_cuda
    MPS: grid_sampler_2d_mps
  autogen: grid_sampler_2d.out
  tags: core

# `grid_sampler_2d_backward` takes in `output_mask` to optimize performance for
# the case where `input` doesn't require gradient. Gradient for `grid` is always
# computed (only `output_mask[0]` is checked by the implementations).
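#
# A rough usage sketch of `output_mask` (illustration only, not part of the
# schema; the real call sites live in the generated autograd kernels):
#
#   import torch
#   input = torch.randn(1, 3, 8, 8)                # requires_grad=False here
#   grid = torch.rand(1, 4, 4, 2) * 2 - 1          # values in [-1, 1]
#   # 0, 0 = bilinear interpolation, zeros padding (enums from native/GridSampler.h)
#   out = torch.ops.aten.grid_sampler_2d(input, grid, 0, 0, True)
#   grad_output = torch.ones_like(out)
#   output_mask = [input.requires_grad, True]      # grad for `grid` is always computed
#   grad_input, grad_grid = torch.ops.aten.grid_sampler_2d_backward(
#       grad_output, input, grid, 0, 0, True, output_mask)
#   # grad_input is only meaningful when output_mask[0] is True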
2885- func: grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) 2886 dispatch: 2887 CPU: grid_sampler_2d_backward_cpu 2888 CUDA: grid_sampler_2d_backward_cuda 2889 autogen: grid_sampler_2d_backward.out 2890 2891# See NOTE [ grid_sample CPU fallback ] 2892- func: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor 2893 dispatch: 2894 CompositeExplicitAutograd: _grid_sampler_2d_cpu_fallback 2895 autogen: _grid_sampler_2d_cpu_fallback.out 2896 2897- func: _grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) 2898 2899- func: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor 2900 dispatch: 2901 CPU: grid_sampler_3d_cpu 2902 CUDA: grid_sampler_3d_cuda 2903 autogen: grid_sampler_3d.out 2904 2905# `grid_sampler_3d_backward` takes in `output_mask` to optimize performance for 2906# the case where `input` doesn't require gradient. Gradient for `grid` is always 2907# computed (only `output_mask[0]` is checked by the implementations). 2908- func: grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) 2909 dispatch: 2910 CPU: grid_sampler_3d_backward_cpu 2911 CUDA: grid_sampler_3d_backward_cuda 2912 autogen: grid_sampler_3d_backward.out 2913 2914- func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2915 dispatch: 2916 CompositeExplicitAutograd: hann_window 2917 autogen: hann_window.out 2918 2919- func: hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2920 dispatch: 2921 CompositeExplicitAutograd: hann_window 2922 autogen: hann_window.periodic_out 2923 2924- func: hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2925 dispatch: 2926 CompositeExplicitAutograd: hamming_window 2927 autogen: hamming_window.out 2928 2929- func: hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2930 dispatch: 2931 CompositeExplicitAutograd: hamming_window 2932 autogen: hamming_window.periodic_out 2933 2934- func: hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2935 dispatch: 2936 CompositeExplicitAutograd: hamming_window 2937 autogen: hamming_window.periodic_alpha_out 2938 2939- func: hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2940 dispatch: 2941 CompositeExplicitAutograd: hamming_window 2942 autogen: hamming_window.periodic_alpha_beta_out 2943 2944- func: kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor 2945 dispatch: 2946 CompositeExplicitAutograd: kaiser_window 2947 autogen: kaiser_window.out 2948 2949- func: kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2950 dispatch: 2951 CompositeExplicitAutograd: kaiser_window 2952 autogen: kaiser_window.periodic_out 2953 2954- func: kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 2955 dispatch: 2956 CompositeExplicitAutograd: kaiser_window 2957 autogen: kaiser_window.beta_out 2958 2959- func: hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor 2960 2961- func: group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor 2962 2963- func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) 2964 dispatch: 2965 CPU, CUDA: native_group_norm 2966 CompositeExplicitAutograd: math_group_norm 2967 autogen: native_group_norm.out 2968 tags: core 2969 2970- func: native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor) 2971 dispatch: 2972 CPU, CUDA: native_group_norm_backward 2973 autogen: native_group_norm_backward.out 2974 tags: core 2975 2976# Real to complex forward FFT 2977- func: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor 2978 variants: function 2979 dispatch: 2980 CPU: _fft_r2c_mkl 2981 CUDA: _fft_r2c_cufft 2982 MPS: _fft_r2c_mps 2983 2984- func: _fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!) 2985 variants: function 2986 dispatch: 2987 CPU: _fft_r2c_mkl_out 2988 CUDA: _fft_r2c_cufft_out 2989 MPS: _fft_r2c_mps_out 2990 2991# Complex to real inverse FFT 2992- func: _fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor 2993 variants: function 2994 dispatch: 2995 CPU: _fft_c2r_mkl 2996 CUDA: _fft_c2r_cufft 2997 MPS: _fft_c2r_mps 2998 2999- func: _fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!) 3000 variants: function 3001 dispatch: 3002 CPU: _fft_c2r_mkl_out 3003 CUDA: _fft_c2r_cufft_out 3004 MPS: _fft_c2r_mps_out 3005 3006# Standard complex to complex FFT (forward or backward) 3007- func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor 3008 variants: function 3009 dispatch: 3010 CPU: _fft_c2c_mkl 3011 CUDA: _fft_c2c_cufft 3012 MPS: _fft_c2c_mps 3013 3014- func: _fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) 
3015 variants: function 3016 dispatch: 3017 CPU: _fft_c2c_mkl_out 3018 CUDA: _fft_c2c_cufft_out 3019 MPS: _fft_c2c_mps_out 3020 3021- func: _validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> () 3022 device_check: NoCheck 3023 variants: function 3024 dispatch: 3025 CPU: _validate_compressed_sparse_indices_cpu 3026 CUDA: _validate_compressed_sparse_indices_cuda 3027 3028- func: _cufft_get_plan_cache_size(DeviceIndex device_index) -> int 3029 3030- func: _cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int 3031 3032- func: _cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> () 3033 3034- func: _cufft_clear_plan_cache(DeviceIndex device_index) -> () 3035 3036- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor 3037 device_check: NoCheck # TensorIterator 3038 structured_delegate: index.Tensor_out 3039 variants: function, method 3040 dispatch: 3041 QuantizedCPU: quantized_index 3042 tags: [core, dynamic_output_shape] 3043 # NB: This function is special-cased in tools/autograd/gen_variable_type.py 3044 # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: 3045 # - Tensor Tensor::index(ArrayRef<TensorIndex> indices) 3046 # - Tensor Tensor::index(std::initializer_list<TensorIndex> indices) 3047 3048- func: index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!) 3049 device_check: NoCheck 3050 structured: True 3051 structured_inherits: TensorIteratorBase 3052 precomputed: 3053 - indices -> DimVector sizes, DimVector strides 3054 dispatch: 3055 CPU, CUDA, MPS: index_out 3056 3057# Used by inductor to signal indexing without bounds checks 3058# Note that we don't support boolean indexing, to avoid dynamic output shapes 3059- func: _unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor 3060 variants: function 3061 dispatch: 3062 CompositeExplicitAutograd: _unsafe_index 3063 3064- func: index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!) 3065 structured: True 3066 variants: function 3067 precomputed: 3068 - dim -> int dim 3069 dispatch: 3070 CPU, CUDA: index_copy_out 3071 3072- func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) 3073 variants: method 3074 structured_delegate: index_copy.out 3075 3076- func: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor 3077 variants: function, method 3078 structured_delegate: index_copy.out 3079 3080- func: index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!) 3081 variants: method 3082 3083- func: index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor 3084 variants: function, method 3085 3086- func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) 
3087 device_check: NoCheck # delegate to _index_put_impl_, which leverages TensorIterator 3088 variants: function, method 3089 dispatch: 3090 CompositeExplicitAutograd: index_put_ 3091 autogen: index_put.out 3092 # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: 3093 # - Tensor & Tensor::index_put_(ArrayRef<TensorIndex> indices, Tensor const & rhs) 3094 # - Tensor & Tensor::index_put_(ArrayRef<TensorIndex> indices, Scalar v) 3095 # - Tensor & Tensor::index_put_(std::initializer_list<TensorIndex> indices, Tensor const & rhs) 3096 # - Tensor & Tensor::index_put_(std::initializer_list<TensorIndex> indices, Scalar v) 3097 3098- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor 3099 device_check: NoCheck # delegate to _index_put_impl_ after clone, which leverages TensorIterator 3100 variants: function, method 3101 dispatch: 3102 CompositeExplicitAutograd: index_put 3103 tags: core 3104 3105- func: _unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor 3106 device_check: NoCheck # delegate to _index_put_impl_ after clone, which leverages TensorIterator 3107 variants: function 3108 dispatch: 3109 CompositeExplicitAutograd: _unsafe_index_put 3110 3111- func: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!) 3112 device_check: NoCheck # TensorIterator 3113 variants: function 3114 dispatch: 3115 CPU, CUDA, MPS: _index_put_impl_ 3116 QuantizedCPU: _index_put_impl_quantized_cpu_ 3117 QuantizedCUDA: _index_put_impl_quantized_cuda_ 3118 autogen: _index_put_impl, _index_put_impl.out 3119 3120- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor 3121 variants: function 3122 3123- func: isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor 3124 variants: function, method 3125 3126- func: isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) 3127 variants: function 3128 structured: True 3129 dispatch: 3130 CPU, CUDA: isin_Tensor_Tensor_out 3131 MPS: isin_Tensor_Tensor_out_mps 3132 3133- func: isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor 3134 variants: function 3135 structured_delegate: isin.Tensor_Tensor_out 3136 3137- func: isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) 3138 variants: function 3139 structured: True 3140 dispatch: 3141 CPU, CUDA: isin_Tensor_Scalar_out 3142 3143- func: isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor 3144 variants: function 3145 structured_delegate: isin.Tensor_Scalar_out 3146 3147- func: isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) 
3148 variants: function 3149 structured: True 3150 dispatch: 3151 CPU, CUDA: isin_Scalar_Tensor_out 3152 3153- func: isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor 3154 variants: function 3155 structured_delegate: isin.Scalar_Tensor_out 3156 3157- func: isnan(Tensor self) -> Tensor 3158 variants: function, method 3159 device_check: NoCheck 3160 device_guard: False 3161 dispatch: 3162 CPU, CUDA, MPS: isnan 3163 SparseCPU, SparseCUDA: isnan_sparse 3164 SparseCsrCPU, SparseCsrCUDA: isnan_sparse_csr 3165 autogen: isnan.out 3166 tags: [core, pointwise] 3167 3168- func: is_distributed(Tensor self) -> bool 3169 variants: function, method 3170 device_check: NoCheck 3171 device_guard: False 3172 3173- func: is_floating_point(Tensor self) -> bool 3174 variants: function, method 3175 device_check: NoCheck 3176 device_guard: False 3177 manual_cpp_binding: True 3178 3179- func: is_complex(Tensor self) -> bool 3180 variants: function, method 3181 device_check: NoCheck 3182 device_guard: False 3183 manual_cpp_binding: True 3184 3185- func: is_conj(Tensor self) -> bool 3186 variants: function, method 3187 device_guard: False 3188 manual_cpp_binding: True 3189 3190- func: _is_zerotensor(Tensor self) -> bool 3191 variants: function, method 3192 device_guard: False 3193 manual_cpp_binding: True 3194 3195- func: is_neg(Tensor self) -> bool 3196 variants: function, method 3197 device_guard: False 3198 manual_cpp_binding: True 3199 3200- func: isreal(Tensor self) -> Tensor 3201 variants: function, method 3202 3203- func: is_nonzero(Tensor self) -> bool 3204 variants: function, method 3205 device_check: NoCheck 3206 device_guard: False 3207 3208- func: is_same_size(Tensor self, Tensor other) -> bool 3209 variants: function, method 3210 device_check: NoCheck 3211 device_guard: False 3212 dispatch: 3213 NestedTensorCPU, NestedTensorCUDA: nested_is_same_size 3214 CompositeExplicitAutograd: is_same_size 3215 3216- func: is_signed(Tensor self) -> bool 3217 variants: function, method 3218 device_check: NoCheck 3219 device_guard: False 3220 manual_cpp_binding: True 3221 3222- func: is_inference(Tensor self) -> bool 3223 variants: function, method 3224 device_check: NoCheck 3225 device_guard: False 3226 manual_cpp_binding: True 3227 3228- func: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor 3229 3230- func: kron(Tensor self, Tensor other) -> Tensor 3231 variants: function, method 3232 3233- func: kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 3234 3235- func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) 3236 variants: function, method 3237 dispatch: 3238 CompositeExplicitAutograd: kthvalue 3239 3240- func: kthvalue.values(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) 3241 dispatch: 3242 CPU: kthvalue_out_cpu 3243 CUDA: kthvalue_out_cuda 3244 3245- func: kthvalue.dimname(Tensor self, int k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) 3246 variants: function, method 3247 3248- func: kthvalue.dimname_out(Tensor self, int k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) 3249 3250- func: layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? 
bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor 3251 dispatch: 3252 CompositeImplicitAutograd: layer_norm_symint 3253 3254- func: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) 3255 dispatch: 3256 CPU: layer_norm_cpu 3257 CUDA: layer_norm_cuda 3258 MPS: layer_norm_mps 3259 CompositeExplicitAutograd: math_native_layer_norm 3260 NestedTensorCPU, NestedTensorCUDA: nested_layer_norm 3261 autogen: native_layer_norm.out 3262 tags: core 3263 3264- func: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) 3265 dispatch: 3266 CPU: layer_norm_backward_cpu 3267 CUDA: layer_norm_backward_cuda 3268 MPS: layer_norm_backward_mps 3269 NestedTensorCPU, NestedTensorCUDA: layer_norm_backward_nested 3270 autogen: native_layer_norm_backward.out 3271 tags: core 3272 3273- func: rms_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor 3274 3275- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor 3276 variants: function, method 3277 dispatch: 3278 CompositeExplicitAutograd: nan_to_num 3279 SparseCPU, SparseCUDA: nan_to_num_sparse 3280 tags: pointwise 3281 3282- func: nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!) 3283 variants: function, method 3284 dispatch: 3285 CompositeExplicitAutograd: nan_to_num_ 3286 SparseCPU, SparseCUDA: nan_to_num_sparse_ 3287 tags: pointwise 3288 3289- func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) 3290 dispatch: 3291 CPU, CUDA: nan_to_num_out 3292 MPS: nan_to_num_out_mps 3293 SparseCPU, SparseCUDA: nan_to_num_sparse_out 3294 tags: pointwise 3295 3296- func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor 3297 python_module: nn 3298 dispatch: 3299 CompositeImplicitAutograd: linear 3300 NestedTensorCPU, NestedTensorCUDA: nested_linear 3301 MPS: _mps_linear 3302 3303- func: linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) 3304 dispatch: 3305 NestedTensorCPU, NestedTensorCUDA: nested_linear_backward 3306 MPS: mps_linear_backward 3307 autogen: linear_backward.out 3308 3309- func: linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) 3310 python_module: nn 3311 dispatch: 3312 CompositeExplicitAutograd: linear_out 3313 3314- func: mkldnn_linear(Tensor self, Tensor weight, Tensor? 
bias=None) -> Tensor
  python_module: nn
  dispatch:
    MkldnnCPU: mkldnn_linear
  autogen: mkldnn_linear.out

- func: mkldnn_linear_backward_input(int[] input_size, Tensor grad_output, Tensor weight) -> Tensor
  dispatch:
    MkldnnCPU: mkldnn_linear_backward_input
  autogen: mkldnn_linear_backward_input.out

- func: mkldnn_linear_backward_weights(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined) -> (Tensor, Tensor)
  dispatch:
    MkldnnCPU: mkldnn_linear_backward_weights
  autogen: mkldnn_linear_backward_weights.out

- func: mkldnn_linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
  dispatch:
    MkldnnCPU: mkldnn_linear_backward
  autogen: mkldnn_linear_backward.out

- func: _cslt_compress(Tensor input) -> Tensor
  dispatch:
    CUDA: _cslt_compress

- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0) -> Tensor
  dispatch:
    CUDA: _cslt_sparse_mm

- func: _cslt_sparse_mm_search(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False) -> int
  dispatch:
    CUDA: _cslt_sparse_mm_search

- func: _sparse_semi_structured_tile(Tensor input, str algorithm="", bool use_cutlass=True) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
  dispatch:
    CUDA: _sparse_semi_structured_tile

- func: _sparse_semi_structured_apply(Tensor input, Tensor thread_masks) -> (Tensor, Tensor)
  dispatch:
    CUDA: _sparse_semi_structured_apply

- func: _sparse_semi_structured_apply_dense(Tensor input, Tensor thread_masks) -> Tensor
  dispatch:
    CUDA: _sparse_semi_structured_apply_dense

# DEPRECATED: Use torch._sparse_semi_structured_mm/torch._sparse_semi_structured_addmm instead
- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor
  dispatch:
    CUDA: _sparse_semi_structured_linear

- func: _sparse_semi_structured_mm(Tensor mat1, Tensor mat1_meta, Tensor mat2, *, ScalarType? out_dtype=None) -> Tensor
  dispatch:
    CUDA: _sparse_semi_structured_mm

- func: _sparse_semi_structured_addmm(Tensor input, Tensor mat1, Tensor mat1_meta, Tensor mat2, *, Scalar alpha=1, Scalar beta=1, ScalarType? out_dtype=None) -> Tensor
  dispatch:
    CUDA: _sparse_semi_structured_addmm

- func: _mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str?
activation=None) -> Tensor 3373 dispatch: 3374 CUDA: _mixed_dtypes_linear 3375 3376- func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor 3377 3378- func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor 3379 3380- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int) 3381 3382- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor 3383 3384- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor 3385 3386- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor 3387 3388- func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor 3389 3390- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor 3391 3392- func: ldexp.Tensor(Tensor self, Tensor other) -> Tensor 3393 variants: function, method 3394 3395- func: ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!) 3396 variants: function, method 3397 tags: pointwise 3398 3399- func: ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 3400 tags: pointwise 3401 3402- func: linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 3403 dispatch: 3404 CompositeExplicitAutograd: linspace 3405 3406- func: linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 3407 category_override: factory 3408 dispatch: 3409 CompositeExplicitAutograd: linspace 3410 3411- func: linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 3412 category_override: factory 3413 dispatch: 3414 CompositeExplicitAutograd: linspace 3415 3416- func: linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 3417 category_override: factory 3418 dispatch: 3419 CompositeExplicitAutograd: linspace 3420 3421- func: linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) 3422 dispatch: 3423 CPU, Meta: linspace_out 3424 CUDA: linspace_cuda_out 3425 MPS: linspace_out_mps 3426 3427- func: linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) 3428 category_override: factory 3429 dispatch: 3430 CompositeExplicitAutograd: linspace_out 3431 3432- func: linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) 3433 category_override: factory 3434 dispatch: 3435 CompositeExplicitAutograd: linspace_out 3436 3437- func: linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) 3438 category_override: factory 3439 dispatch: 3440 CompositeExplicitAutograd: linspace_out 3441 3442- func: log(Tensor self) -> Tensor 3443 device_check: NoCheck # TensorIterator 3444 structured_delegate: log.out 3445 variants: function, method 3446 tags: [core, pointwise] 3447 3448- func: log_(Tensor(a!) self) -> Tensor(a!) 
3449 device_check: NoCheck # TensorIterator 3450 structured_delegate: log.out 3451 variants: function, method 3452 tags: pointwise 3453 3454- func: log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 3455 device_check: NoCheck # TensorIterator 3456 structured: True 3457 structured_inherits: TensorIteratorBase 3458 dispatch: 3459 CPU, CUDA: log_out 3460 MPS: log_out_mps 3461 tags: pointwise 3462 3463- func: log10(Tensor self) -> Tensor 3464 device_check: NoCheck # TensorIterator 3465 structured_delegate: log10.out 3466 variants: function, method 3467 tags: [core, pointwise] 3468 3469- func: log10_(Tensor(a!) self) -> Tensor(a!) 3470 device_check: NoCheck # TensorIterator 3471 structured_delegate: log10.out 3472 variants: function, method 3473 tags: pointwise 3474 3475- func: log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 3476 device_check: NoCheck # TensorIterator 3477 structured: True 3478 structured_inherits: TensorIteratorBase 3479 dispatch: 3480 CPU, CUDA: log10_out 3481 MPS: log10_out_mps 3482 tags: pointwise 3483 3484- func: log1p(Tensor self) -> Tensor 3485 device_check: NoCheck # TensorIterator 3486 structured_delegate: log1p.out 3487 variants: function, method 3488 dispatch: 3489 SparseCPU, SparseCUDA: log1p_sparse 3490 SparseCsrCPU, SparseCsrCUDA: log1p_sparse_csr 3491 tags: [core, pointwise] 3492 3493- func: log1p_(Tensor(a!) self) -> Tensor(a!) 3494 device_check: NoCheck # TensorIterator 3495 structured_delegate: log1p.out 3496 variants: function, method 3497 dispatch: 3498 SparseCPU, SparseCUDA: log1p_sparse_ 3499 SparseCsrCPU, SparseCsrCUDA: log1p_sparse_csr_ 3500 tags: pointwise 3501 3502- func: log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 3503 device_check: NoCheck # TensorIterator 3504 structured: True 3505 structured_inherits: TensorIteratorBase 3506 dispatch: 3507 CPU, CUDA: log1p_out 3508 MPS: log1p_out_mps 3509 SparseCPU, SparseCUDA: log1p_sparse_out 3510 SparseCsrCPU, SparseCsrCUDA: log1p_sparse_csr_out 3511 tags: pointwise 3512 3513- func: log2(Tensor self) -> Tensor 3514 device_check: NoCheck # TensorIterator 3515 structured_delegate: log2.out 3516 variants: function, method 3517 tags: [core, pointwise] 3518 3519- func: log2_(Tensor(a!) self) -> Tensor(a!) 3520 device_check: NoCheck # TensorIterator 3521 structured_delegate: log2.out 3522 variants: function, method 3523 tags: pointwise 3524 3525- func: log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 3526 device_check: NoCheck # TensorIterator 3527 structured: True 3528 structured_inherits: TensorIteratorBase 3529 dispatch: 3530 CPU, CUDA: log2_out 3531 MPS: log2_out_mps 3532 tags: pointwise 3533 3534- func: logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 3535 structured: True 3536 structured_inherits: TensorIteratorBase 3537 dispatch: 3538 CPU, CUDA: logaddexp_out 3539 MPS: logaddexp_out_mps 3540 tags: pointwise 3541 3542- func: logaddexp(Tensor self, Tensor other) -> Tensor 3543 variants: method, function 3544 structured_delegate: logaddexp.out 3545 tags: pointwise 3546 3547- func: logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
3548 structured: True 3549 structured_inherits: TensorIteratorBase 3550 dispatch: 3551 CPU, CUDA: logaddexp2_out 3552 MPS: logaddexp2_out_mps 3553 tags: pointwise 3554 3555- func: logaddexp2(Tensor self, Tensor other) -> Tensor 3556 variants: method, function 3557 structured_delegate: logaddexp2.out 3558 tags: pointwise 3559 3560- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor 3561 device_check: NoCheck # TensorIterator 3562 structured_delegate: xlogy.OutTensor 3563 variants: function, method 3564 tags: pointwise 3565 3566- func: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor 3567 device_check: NoCheck # TensorIterator 3568 variants: function 3569 dispatch: 3570 CompositeExplicitAutograd: xlogy 3571 tags: pointwise 3572 3573- func: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor 3574 device_check: NoCheck # TensorIterator 3575 variants: function, method 3576 dispatch: 3577 CompositeExplicitAutograd: xlogy 3578 tags: pointwise 3579 3580# xlogy: inplace variant 3581- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 3582 device_check: NoCheck # TensorIterator 3583 variants: function, method 3584 structured_delegate: xlogy.OutTensor 3585 tags: pointwise 3586 3587- func: xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!) 3588 device_check: NoCheck # TensorIterator 3589 variants: function, method 3590 dispatch: 3591 CompositeExplicitAutograd: xlogy_ 3592 3593# xlogy: out variant 3594- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 3595 device_check: NoCheck # TensorIterator 3596 structured: True 3597 structured_inherits: TensorIteratorBase 3598 variants: function 3599 dispatch: 3600 CPU, CUDA: xlogy_out 3601 MPS: xlogy_out_mps 3602 tags: pointwise 3603 3604- func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 3605 device_check: NoCheck # TensorIterator 3606 variants: function 3607 dispatch: 3608 CompositeExplicitAutograd: xlogy_out 3609 tags: pointwise 3610 3611- func: xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 3612 device_check: NoCheck # TensorIterator 3613 variants: function 3614 dispatch: 3615 CompositeExplicitAutograd: xlogy_out 3616 tags: pointwise 3617 3618- func: logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 3619 dispatch: 3620 CompositeExplicitAutograd: logspace 3621 3622- func: logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 3623 category_override: factory 3624 dispatch: 3625 CompositeExplicitAutograd: logspace 3626 3627- func: logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 3628 category_override: factory 3629 dispatch: 3630 CompositeExplicitAutograd: logspace 3631 3632- func: logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 3633 category_override: factory 3634 dispatch: 3635 CompositeExplicitAutograd: logspace 3636 3637- func: logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) 
3638 dispatch: 3639 CPU, Meta: logspace_out 3640 CUDA: logspace_cuda_out 3641 3642- func: logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) 3643 category_override: factory 3644 dispatch: 3645 CompositeExplicitAutograd: logspace_out 3646 3647- func: logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) 3648 category_override: factory 3649 dispatch: 3650 CompositeExplicitAutograd: logspace_out 3651 3652- func: logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) 3653 category_override: factory 3654 dispatch: 3655 CompositeExplicitAutograd: logspace_out 3656 3657# log_softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models. 3658- func: log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor 3659 variants: function, method 3660 3661- func: log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) 3662 variants: function 3663 dispatch: 3664 CompositeExplicitAutograd: log_softmax_out 3665 3666- func: log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor 3667 variants: function, method 3668 3669- func: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor 3670 structured_delegate: _log_softmax.out 3671 tags: core 3672 3673- func: _log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) 3674 structured: True 3675 dispatch: 3676 CPU: log_softmax_cpu_out 3677 CUDA: log_softmax_cuda_out 3678 MPS: log_softmax_mps_out 3679 3680- func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor 3681 structured_delegate: _log_softmax_backward_data.out 3682 3683- func: _log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!) 3684 structured: True 3685 dispatch: 3686 CPU: log_softmax_backward_cpu_out 3687 CUDA: log_softmax_backward_cuda_out 3688 MPS: log_softmax_backward_mps_out 3689 3690- func: _logcumsumexp(Tensor self, int dim) -> Tensor 3691 dispatch: 3692 CPU: _logcumsumexp_cpu 3693 CUDA: _logcumsumexp_cuda 3694 3695- func: _logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) 3696 dispatch: 3697 CPU: _logcumsumexp_out_cpu 3698 CUDA: _logcumsumexp_out_cuda 3699 3700- func: logcumsumexp(Tensor self, int dim) -> Tensor 3701 variants: function, method 3702 dispatch: 3703 CompositeExplicitAutograd: logcumsumexp 3704 3705- func: logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) 3706 dispatch: 3707 CompositeExplicitAutograd: logcumsumexp_out 3708 3709- func: logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor 3710 variants: function, method 3711 3712- func: logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) 3713 3714- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor 3715 device_check: NoCheck # TensorIterator 3716 variants: function, method 3717 dispatch: 3718 CompositeExplicitAutograd: logsumexp 3719 3720- func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
3721 device_check: NoCheck # TensorIterator 3722 dispatch: 3723 # calls squeeze 3724 CompositeExplicitAutogradNonFunctional: logsumexp_out 3725 3726- func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor 3727 device_check: NoCheck # TensorIterator 3728 variants: function, method 3729 3730- func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 3731 device_check: NoCheck # TensorIterator 3732 3733- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor 3734 3735- func: matmul(Tensor self, Tensor other) -> Tensor 3736 variants: function, method 3737 dispatch: 3738 CompositeImplicitAutograd: matmul 3739 NestedTensorCPU, NestedTensorCUDA: matmul_nested 3740 3741- func: matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor) 3742 dispatch: 3743 NestedTensorCPU, NestedTensorCUDA: matmul_backward_nested 3744 autogen: matmul_backward.out 3745 3746- func: matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 3747 dispatch: 3748 CompositeImplicitAutograd: matmul_out 3749 NestedTensorCPU, NestedTensorCUDA: matmul_out_nested 3750 3751# Alias to linalg.matrix_power 3752- func: matrix_power(Tensor self, int n) -> Tensor 3753 variants: function, method 3754 3755# Alias to linalg.matrix_power 3756- func: matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) 3757 3758# Alias to linalg.matrix_exp 3759- func: matrix_exp(Tensor self) -> Tensor 3760 variants: function, method 3761 3762# This function should be deprecated in favor of differential_analytic_matrix_function in FunctionsManual.cpp 3763- func: matrix_exp_backward(Tensor self, Tensor grad) -> Tensor 3764 3765# DEPRECATED: Use torch.aminmax instead 3766- func: _aminmax(Tensor self) -> (Tensor, Tensor) 3767 dispatch: 3768 CPU, CUDA: _aminmax_all 3769 autogen: _aminmax.out 3770 3771# DEPRECATED: Use torch.aminmax instead 3772- func: _aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor) 3773 dispatch: 3774 CPU, CUDA: _aminmax 3775 autogen: _aminmax.dim_out 3776 3777- func: aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max) 3778 device_check: NoCheck # TensorIterator 3779 structured_delegate: aminmax.out 3780 variants: function, method 3781 3782- func: aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) 3783 device_check: NoCheck # TensorIterator 3784 structured: True 3785 dispatch: 3786 CPU, CUDA: aminmax_out 3787 MPS: aminmax_out_mps 3788 3789- func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor 3790 dispatch: 3791 CPU, CUDA: _compute_linear_combination 3792 3793- func: _compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!) 3794 dispatch: 3795 CPU, CUDA: _compute_linear_combination_out 3796 3797- func: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) 3798 device_check: NoCheck # TensorIterator 3799 structured_delegate: max.dim_max 3800 variants: function, method 3801 dispatch: 3802 QuantizedCPU, QuantizedCUDA: qmax 3803 tags: core 3804 3805- func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) 
indices) 3806 device_check: NoCheck # TensorIterator 3807 structured: True 3808 precomputed: 3809 - dim -> int dim 3810 dispatch: 3811 CPU, CUDA: max_out 3812 MPS: max_out_mps 3813 3814- func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) 3815 device_check: NoCheck # TensorIterator 3816 variants: function, method 3817 3818- func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) 3819 device_check: NoCheck # TensorIterator 3820 3821- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor 3822 variants: function 3823 device_check: NoCheck 3824 device_guard: False 3825 dispatch: 3826 CompositeImplicitAutograd: value_selecting_reduction_backward_symint 3827 3828- func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor 3829 variants: function, method 3830 structured_delegate: amax.out 3831 tags: core 3832 3833- func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 3834 structured: True 3835 dispatch: 3836 CPU, CUDA: amax_out 3837 MPS: amax_out_mps 3838 3839# Return: (Tensor output, Tensor indices) 3840- func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) 3841 3842- func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor 3843 3844- func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor 3845 dispatch: 3846 CompositeImplicitAutograd: max_pool2d 3847 MPS: mps_max_pool2d 3848 3849- func: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor 3850 dispatch: 3851 MPS: mps_max_pool2d_backward 3852 autogen: max_pool2d_backward.out 3853 3854- func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor 3855 dispatch: 3856 MkldnnCPU: mkldnn_max_pool2d 3857 autogen: mkldnn_max_pool2d.out 3858 3859- func: mkldnn_max_pool2d_backward(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor 3860 dispatch: 3861 MkldnnCPU: mkldnn_max_pool2d_backward 3862 autogen: mkldnn_max_pool2d_backward.out 3863 3864- func: mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor 3865 dispatch: 3866 MkldnnCPU: mkldnn_max_pool3d 3867 autogen: mkldnn_max_pool3d.out 3868 3869- func: mkldnn_max_pool3d_backward(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor 3870 dispatch: 3871 MkldnnCPU: mkldnn_max_pool3d_backward 3872 autogen: mkldnn_max_pool3d_backward.out 3873 3874- func: quantized_max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor 3875 dispatch: 3876 QuantizedCPU: quantized_max_pool1d 3877 autogen: quantized_max_pool1d.out 3878 3879- func: quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool 
ceil_mode=False) -> Tensor 3880 dispatch: 3881 QuantizedCPU: quantized_max_pool2d 3882 QuantizedCUDA: quantized_max_pool2d_cudnn 3883 autogen: quantized_max_pool2d.out 3884 3885- func: quantized_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor 3886 dispatch: 3887 QuantizedCPU: quantized_max_pool3d 3888 autogen: quantized_max_pool3d.out 3889 3890- func: max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor 3891 3892# The CPU and GPU dispatch variants are named weirdly here because otherwise there 3893# are namespacing issues in C++ 3894- func: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor 3895 device_check: NoCheck # TensorIterator 3896 variants: function, method 3897 dispatch: 3898 CompositeExplicitAutograd: mean 3899 tags: core 3900 3901# For normal naming convention this should be `mean.out`. However since we already have `mean.out` we have to rename this. 3902# FIXME: fix CI jobs and re-enable this 3903#- func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 3904# device_check: NoCheck # TensorIterator 3905# dispatch: 3906# CompositeExplicitAutograd: mean_dtype_out 3907 3908- func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor 3909 structured_delegate: mean.out 3910 device_check: NoCheck # TensorIterator 3911 variants: function, method 3912 dispatch: 3913 QuantizedCPU: mean_quantized_cpu 3914 tags: core 3915 3916- func: mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 3917 structured: True 3918 device_check: NoCheck # TensorIterator 3919 dispatch: 3920 CPU, CUDA: mean_out 3921 MPS: mean_out_mps 3922 QuantizedCPU: mean_out_quantized_cpu 3923 3924- func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor 3925 device_check: NoCheck # TensorIterator 3926 variants: function, method 3927 3928- func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 3929 device_check: NoCheck # TensorIterator 3930 3931- func: nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor 3932 device_check: NoCheck # Composite 3933 variants: function, method 3934 3935- func: nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 3936 device_check: NoCheck # Composite 3937 3938- func: median(Tensor self) -> Tensor 3939 variants: function, method 3940 dispatch: 3941 CPU: median_cpu 3942 CUDA: median_cuda 3943 MPS: median_mps 3944 autogen: median.out 3945 3946- func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) 3947 variants: function, method 3948 dispatch: 3949 CompositeExplicitAutograd: median 3950 3951- func: median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) 3952 dispatch: 3953 CPU: median_out_cpu 3954 CUDA: median_out_cuda 3955 MPS: median_out_mps 3956 3957- func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) 3958 variants: function, method 3959 3960- func: median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) 3961 3962- func: nanmedian(Tensor self) -> Tensor 3963 variants: function, method 3964 dispatch: 3965 CPU: nanmedian_cpu 3966 CUDA: nanmedian_cuda 3967 autogen: nanmedian.out 3968 3969- func: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) 3970 variants: function, method 3971 dispatch: 3972 CompositeExplicitAutograd: nanmedian 3973 3974- func: nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) 3975 dispatch: 3976 CPU: nanmedian_out_cpu 3977 CUDA: nanmedian_out_cuda 3978 3979- func: nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) 3980 variants: function, method 3981 3982- func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) 3983 3984- func: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) 3985 device_check: NoCheck # TensorIterator 3986 structured_delegate: min.dim_min 3987 variants: function, method 3988 dispatch: 3989 QuantizedCPU, QuantizedCUDA: qmin 3990 tags: core 3991 3992- func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) 3993 device_check: NoCheck # TensorIterator 3994 structured: True 3995 precomputed: 3996 - dim -> int dim 3997 dispatch: 3998 CPU, CUDA: min_out 3999 MPS: min_out_mps 4000 4001- func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) 4002 device_check: NoCheck # TensorIterator 4003 variants: function, method 4004 4005- func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) 4006 device_check: NoCheck # TensorIterator 4007 4008- func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor 4009 variants: function, method 4010 structured_delegate: amin.out 4011 tags: core 4012 4013- func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 4014 structured: True 4015 dispatch: 4016 CPU, CUDA: amin_out 4017 MPS: amin_out_mps 4018 4019# TODO: Add this function to MPS dispatch key so that we avoid declaring it in 4020# native_functions.yaml 4021# https://github.com/pytorch/pytorch/issues/77394 4022- func: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor 4023 dispatch: 4024 MPS: _mps_convolution 4025 autogen: _mps_convolution.out 4026 4027- func: mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) 4028 dispatch: 4029 MPS: mps_convolution_backward 4030 autogen: mps_convolution_backward.out 4031 4032- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? 
bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor 4033 dispatch: 4034 CompositeExplicitAutograd: mkldnn_convolution 4035 autogen: mkldnn_convolution.out 4036 4037- func: mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor) 4038 dispatch: 4039 CPU: mkldnn_rnn_layer 4040 MkldnnCPU: mkldnn_rnn_layer 4041 autogen: mkldnn_rnn_layer.out 4042 4043- func: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) 4044 dispatch: 4045 CPU: mkldnn_rnn_layer_backward 4046 autogen: mkldnn_rnn_layer_backward.out 4047 4048- func: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) 4049 dispatch: 4050 CUDA: miopen_batch_norm 4051 autogen: miopen_batch_norm.out 4052 4053- func: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor) 4054 dispatch: 4055 CUDA: miopen_batch_norm_backward 4056 autogen: miopen_batch_norm_backward.out 4057 4058- func: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor 4059 dispatch: 4060 CUDA: miopen_convolution 4061 autogen: miopen_convolution.out 4062 4063- func: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor 4064 dispatch: 4065 CUDA: miopen_convolution_transpose 4066 autogen: miopen_convolution_transpose.out 4067 4068- func: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor 4069 dispatch: 4070 CUDA: miopen_depthwise_convolution 4071 autogen: miopen_depthwise_convolution.out 4072 4073- func: miopen_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor 4074 dispatch: 4075 CUDA: miopen_convolution_relu 4076 4077- func: miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor 4078 dispatch: 4079 CUDA: miopen_convolution_add_relu 4080 4081- func: miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? 
dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) 4082 dispatch: 4083 CUDA: miopen_rnn 4084 autogen: miopen_rnn.out 4085 tags: nondeterministic_seeded 4086 4087 4088- func: miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) 4089 dispatch: 4090 CUDA: miopen_rnn_backward 4091 autogen: miopen_rnn_backward.out 4092 4093- func: mm(Tensor self, Tensor mat2) -> Tensor 4094 structured_delegate: mm.out 4095 variants: function, method 4096 dispatch: 4097 SparseCPU, SparseCUDA: _sparse_mm 4098 SparseCsrCPU, SparseCsrCUDA: _sparse_csr_mm 4099 tags: core 4100 4101- func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) 4102 structured: True 4103 dispatch: 4104 CPU: mm_out_cpu 4105 CUDA: mm_out_cuda 4106 MPS: mm_out_mps 4107 SparseCPU, SparseCUDA: _sparse_mm_out 4108 SparseCsrCPU, SparseCsrCUDA: _sparse_csr_mm_out 4109 4110- func: _int_mm(Tensor self, Tensor mat2) -> Tensor 4111 dispatch: 4112 CPU: _int_mm_cpu 4113 CUDA: _int_mm_cuda 4114 4115- func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) 4116 dispatch: 4117 CPU: _int_mm_out_cpu 4118 CUDA: _int_mm_out_cuda 4119 4120- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor 4121 dispatch: 4122 CPU: _convert_weight_to_int4pack_cpu 4123 CUDA: _convert_weight_to_int4pack_cuda 4124 4125- func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor 4126 dispatch: 4127 CPU: _weight_int4pack_mm_cpu 4128 MPS: _weight_int4pack_mm_mps 4129 CUDA: _weight_int4pack_mm_cuda 4130 4131- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor 4132 dispatch: 4133 CPU: _weight_int8pack_mm_cpu 4134 MPS: _weight_int8pack_mm_mps 4135 4136- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor 4137 python_module: sparse 4138 4139- func: _sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor 4140 python_module: sparse 4141 4142- func: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor 4143 dispatch: 4144 SparseCPU: sparse_sparse_matmul_cpu 4145 SparseCUDA: sparse_sparse_matmul_cuda 4146 autogen: _sparse_sparse_matmul.out 4147 4148- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) 4149 variants: function, method 4150 dispatch: 4151 CPU, CUDA: mode 4152 4153- func: mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) 4154 dispatch: 4155 CompositeExplicitAutograd: mode_out 4156 4157- func: mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) 4158 variants: function, method 4159 4160- func: mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) 4161 4162- func: mul.Tensor(Tensor self, Tensor other) -> Tensor 4163 device_check: NoCheck # TensorIterator 4164 structured_delegate: mul.out 4165 variants: function, method 4166 dispatch: 4167 SparseCPU, SparseCUDA: mul_sparse 4168 SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr 4169 MkldnnCPU: mkldnn_mul 4170 ZeroTensor: mul_zerotensor 4171 NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Tensor 4172 tags: [core, pointwise] 4173 4174- func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 4175 device_check: NoCheck # TensorIterator 4176 structured_delegate: mul.out 4177 variants: method 4178 dispatch: 4179 SparseCPU, SparseCUDA: mul_sparse_ 4180 SparseCsrCPU, SparseCsrCUDA: mul_sparse_csr_ 4181 MkldnnCPU: mkldnn_mul_ 4182 NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul__Tensor 4183 tags: pointwise 4184 4185- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 4186 device_check: NoCheck # TensorIterator 4187 structured: True 4188 structured_inherits: TensorIteratorBase 4189 dispatch: 4190 CPU, CUDA: mul_out 4191 MPS: mul_out_mps 4192 SparseCPU: mul_out_sparse_cpu 4193 SparseCUDA: mul_out_sparse_cuda 4194 SparseCsrCPU, SparseCsrCUDA: mul_out_sparse_csr 4195 MkldnnCPU: mkldnn_mul_out 4196 tags: pointwise 4197 # For C++ only, until we have conversion from C++ numbers to Tensor 4198 4199- func: mul.Scalar(Tensor self, Scalar other) -> Tensor 4200 device_check: NoCheck # TensorIterator 4201 variants: function, method 4202 dispatch: 4203 CompositeExplicitAutograd: mul 4204 SparseCsrCPU, SparseCsrCUDA: mul_scalar_sparse_csr 4205 NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul_Scalar 4206 tags: [core, pointwise] 4207 4208- func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 4209 device_check: NoCheck # TensorIterator 4210 variants: method 4211 dispatch: 4212 CompositeExplicitAutograd: mul_ 4213 SparseCsrCPU, SparseCsrCUDA: mul__scalar_sparse_csr 4214 NestedTensorCPU, NestedTensorCUDA: NestedTensor_mul__Scalar 4215 autogen: mul.Scalar_out 4216 tags: pointwise 4217# multiply, alias for mul 4218 4219- func: multiply.Tensor(Tensor self, Tensor other) -> Tensor 4220 variants: function, method 4221 4222- func: multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 4223 variants: method 4224 4225- func: multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 4226 4227- func: multiply.Scalar(Tensor self, Scalar other) -> Tensor 4228 variants: function, method 4229 4230- func: multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 4231 variants: method 4232 4233- func: mv(Tensor self, Tensor vec) -> Tensor 4234 variants: function, method 4235 dispatch: 4236 CompositeExplicitAutograd: mv 4237 SparseCPU, SparseCUDA: mv_sparse 4238 4239- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) 4240 dispatch: 4241 CompositeExplicitAutograd: mv_out 4242 4243- func: mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) 4244 dispatch: 4245 CPU, CUDA: mvlgamma_out 4246 tags: pointwise 4247 4248- func: mvlgamma(Tensor self, int p) -> Tensor 4249 device_check: NoCheck # TensorIterator 4250 variants: function, method 4251 dispatch: 4252 CompositeExplicitAutograd: mvlgamma 4253 tags: pointwise 4254 4255- func: mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!) 
  device_check: NoCheck # TensorIterator
  variants: method
  dispatch:
    CompositeExplicitAutograd: mvlgamma_
  tags: pointwise

- func: narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor
  variants: function, method
  dispatch:
    CPU: narrow_copy_dense_cpu
    SparseCPU, SparseCUDA: narrow_copy_sparse
    CompositeExplicitAutogradNonFunctional: narrow_copy_dense_symint
  tags: view_copy

- func: narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: narrow_copy_dense_cpu_out

- func: narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
  variants: function, method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeImplicitAutograd: narrow_symint
    NestedTensorCPU, NestedTensorCUDA: narrow_nested_symint

- func: narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a)
  variants: function, method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeImplicitAutograd: narrow_tensor_symint

- func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: batch_norm_cpu
    CUDA: batch_norm_cuda
    MPS: batch_norm_mps
    MkldnnCPU: mkldnn_batch_norm

- func: native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))
  dispatch:
    CUDA: batch_norm_cuda_out
    MPS: batch_norm_mps_out
    CPU: batch_norm_cpu_out

# TODO: In 2 weeks, we should make native_batch_norm composite implicit so that this correct schema percolates correctly through our dispatching
- func: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
  dispatch:
    CPU: _batch_norm_legit_cpu
    CUDA: _batch_norm_legit_cuda
    MPS: _batch_norm_legit_mps
    MkldnnCPU: _mkldnn_batch_norm_legit
  autogen: _native_batch_norm_legit_functional
  tags: core

# HACK: identical to _native_batch_norm_legit, but training is known to be False,
# so we know that running stats will not be mutated.
# The real fix here is batch norm consolidation.
- func: _native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)
  dispatch:
    CompositeExplicitAutograd: _batch_norm_legit_no_training
  autogen: _native_batch_norm_legit_no_training.out
  tags: core

- func: _native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!))
  dispatch:
    CPU: _batch_norm_legit_cpu_out
    CUDA: _batch_norm_legit_cuda_out
    MPS: _batch_norm_legit_mps_out

- func: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor?
bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) 4328 dispatch: 4329 CPU: _batch_norm_legit_no_stats_cpu 4330 CUDA: _batch_norm_legit_no_stats_cuda 4331 MPS: _batch_norm_legit_no_stats_mps 4332 MkldnnCPU: _mkldnn_batch_norm_legit_no_stats 4333 tags: core 4334 4335- func: _native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) 4336 dispatch: 4337 CPU: _batch_norm_legit_no_stats_cpu_out 4338 CUDA: _batch_norm_legit_no_stats_cuda_out 4339 MPS: _batch_norm_legit_no_stats_mps_out 4340 4341- func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) 4342 dispatch: 4343 CUDA: batch_norm_stats_cuda 4344 autogen: batch_norm_stats.out 4345 4346- func: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor 4347 dispatch: 4348 CUDA: batch_norm_elemt_cuda 4349 4350- func: batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!) 4351 dispatch: 4352 CUDA: batch_norm_elemt_cuda_out 4353 4354# for backward compatibility 4355- func: batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor) 4356 dispatch: 4357 CUDA: batch_norm_gather_stats_cuda 4358 autogen: batch_norm_gather_stats.out 4359 4360- func: batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor) 4361 dispatch: 4362 CUDA: batch_norm_gather_stats_with_counts_cuda 4363 autogen: batch_norm_gather_stats_with_counts.out 4364 4365- func: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) 4366 dispatch: 4367 CPU: batch_norm_backward_cpu 4368 CUDA: batch_norm_backward_cuda 4369 MPS: batch_norm_backward_mps 4370 MkldnnCPU: mkldnn_batch_norm_backward 4371 autogen: native_batch_norm_backward.out 4372 4373- func: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor) 4374 dispatch: 4375 CUDA: batch_norm_backward_reduce_cuda 4376 autogen: batch_norm_backward_reduce.out 4377 4378- func: batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor 4379 dispatch: 4380 CUDA: batch_norm_backward_elemt_cuda 4381 autogen: batch_norm_backward_elemt.out 4382 4383- func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor) 4384 dispatch: 4385 CPU: batch_norm_update_stats_cpu 4386 CUDA: batch_norm_update_stats_cuda 4387 autogen: batch_norm_update_stats.out 4388 4389- func: is_vulkan_available() -> bool 4390 4391- func: _nnpack_available() -> bool 4392 4393- func: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? 
bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor 4394 variants: function 4395 dispatch: 4396 CompositeExplicitAutograd: _nnpack_spatial_convolution 4397 autogen: _nnpack_spatial_convolution.out 4398 4399- func: ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4400 device_check: NoCheck 4401 device_guard: False 4402 dispatch: 4403 CompositeExplicitAutograd: ones 4404 autogen: ones.names_out 4405 4406- func: ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4407 dispatch: 4408 CompositeExplicitAutograd: ones 4409 4410- func: ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 4411 dispatch: 4412 CompositeExplicitAutograd: ones_out 4413 4414- func: ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor 4415 dispatch: 4416 # NB: Although this composite mutates on the inside, it is 4417 # non-differentiable so NonFunctional doesn't apply 4418 CompositeExplicitAutograd: ones_like 4419 NestedTensorCPU, NestedTensorCUDA: ones_like 4420 autogen: ones_like.out 4421 4422- func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor 4423 4424- func: cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor 4425 4426- func: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor 4427 dispatch: 4428 CompositeExplicitAutograd: _euclidean_dist 4429 autogen: _euclidean_dist.out 4430 4431- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor 4432 dispatch: 4433 CPU, CUDA: _cdist_forward 4434 MPS: _cdist_forward_mps 4435 autogen: _cdist_forward.out 4436 tags: core 4437 4438- func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor 4439 dispatch: 4440 CPU, CUDA: _cdist_backward 4441 autogen: _cdist_backward.out 4442 4443- func: pdist(Tensor self, float p=2) -> Tensor 4444 4445- func: _pdist_forward(Tensor self, float p=2) -> Tensor 4446 dispatch: 4447 CPU, CUDA: _pdist_forward 4448 autogen: _pdist_forward.out 4449 tags: core 4450 4451- func: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor 4452 dispatch: 4453 CPU, CUDA: _pdist_backward 4454 autogen: _pdist_backward.out 4455 4456- func: cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor 4457 variants: function 4458 4459- func: permute(Tensor(a) self, int[] dims) -> Tensor(a) 4460 variants: function, method 4461 dispatch: 4462 CompositeExplicitAutograd: permute 4463 MPS: permute_mps 4464 SparseCPU, SparseCUDA: permute_sparse_coo 4465 tags: core 4466 4467- func: movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) 4468 variants: function, method 4469 4470- func: movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a) 4471 variants: function, method 4472 4473# moveaxis, alias for movedim 4474- func: moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) 4475 variants: function, method 4476 4477- func: moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a) 4478 variants: function, method 4479 4480# Only exposed from C++ -- in Python, 4481# we expose it as an attribute `T`, not a function. 
4482# 4483# I'd like to name this "T" in C++ too, but 4484# calling a native function "T" causes undefined 4485# behavior on Windows, for reasons I don't understand 4486# (maybe related to capital letter collation somehow...) 4487- func: numpy_T(Tensor(a) self) -> Tensor(a) 4488 variants: method 4489 4490# Exposed on Python as an attribute 'H' 4491- func: matrix_H(Tensor(a) self) -> Tensor(a) 4492 variants: method 4493 4494# Exposed on Python as an attribute 'mT' 4495- func: mT(Tensor(a) self) -> Tensor(a) 4496 variants: method 4497 4498# Exposed on Python as an attribute 'mH' 4499- func: mH(Tensor(a) self) -> Tensor(a) 4500 variants: method 4501 4502- func: adjoint(Tensor(a) self) -> Tensor(a) 4503 variants: function, method 4504 4505- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor 4506 dispatch: 4507 CPU: pixel_shuffle_cpu 4508 MPS: pixel_shuffle_mps 4509 CompositeExplicitAutogradNonFunctional: math_pixel_shuffle 4510 autogen: pixel_shuffle.out 4511 4512- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor 4513 dispatch: 4514 CPU: pixel_unshuffle_cpu 4515 MPS: pixel_unshuffle_mps 4516 CompositeExplicitAutogradNonFunctional: math_pixel_unshuffle 4517 autogen: pixel_unshuffle.out 4518 4519- func: channel_shuffle(Tensor self, SymInt groups) -> Tensor 4520 dispatch: 4521 CPU, CUDA: channel_shuffle 4522 QuantizedCPU: channel_shuffle_quantized_cpu 4523 autogen: channel_shuffle.out 4524 4525- func: native_channel_shuffle(Tensor self, SymInt groups) -> Tensor 4526 dispatch: 4527 CPU: channel_shuffle_cpu 4528 CompositeImplicitAutograd: math_channel_shuffle 4529 4530- func: is_pinned(Tensor self, Device? device=None) -> bool 4531 variants: method 4532 dispatch: 4533 NestedTensorCUDA, CUDA: is_pinned_cuda 4534 MPS: is_pinned_mps 4535 CompositeExplicitAutograd: is_pinned_default 4536 4537# TODO: add a copy kwarg that guarantees that the tensor is put into fresh 4538# pinned memory 4539- func: pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a) 4540 variants: method 4541 4542# Unlike pin_memory, this is guaranteed to give a new non-aliasing tensor 4543- func: _pin_memory(Tensor self, Device? device=None) -> Tensor 4544 dispatch: 4545 CUDA: _pin_memory_cuda 4546 MPS: _pin_memory_mps 4547 NestedTensorCUDA, NestedTensorCPU: _pin_memory_nested 4548 autogen: _pin_memory.out 4549 4550- func: pinverse(Tensor self, float rcond=1e-15) -> Tensor 4551 variants: function, method 4552 4553- func: poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor 4554 variants: function 4555 4556- func: rad2deg(Tensor self) -> Tensor 4557 variants: function, method 4558 dispatch: 4559 CompositeExplicitAutograd: rad2deg 4560 SparseCPU, SparseCUDA: rad2deg_sparse 4561 SparseCsrCPU, SparseCsrCUDA: rad2deg_sparse_csr 4562 4563- func: rad2deg_(Tensor(a!) self) -> Tensor(a!) 4564 variants: function, method 4565 dispatch: 4566 CompositeExplicitAutograd: rad2deg_ 4567 SparseCPU, SparseCUDA: rad2deg_sparse_ 4568 SparseCsrCPU, SparseCsrCUDA: rad2deg_sparse_csr_ 4569 4570- func: rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
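# Illustrative usage sketch (not part of the schema): rad2deg / deg2rad are
# pointwise unit conversions, and the out= overload declared here writes into
# a preallocated tensor instead of allocating a new one, e.g.:
#
#   x = torch.tensor([0.0, 3.141592653589793])
#   torch.rad2deg(x)                  # tensor([  0., 180.])
#   out = torch.empty_like(x)
#   torch.rad2deg(x, out=out)         # same values, written into `out`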
4571 dispatch: 4572 CompositeExplicitAutograd: rad2deg_out 4573 SparseCPU, SparseCUDA: rad2deg_sparse_out 4574 SparseCsrCPU, SparseCsrCUDA: rad2deg_sparse_csr_out 4575 4576- func: deg2rad(Tensor self) -> Tensor 4577 variants: function, method 4578 dispatch: 4579 CompositeExplicitAutograd: deg2rad 4580 SparseCPU, SparseCUDA: deg2rad_sparse 4581 SparseCsrCPU, SparseCsrCUDA: deg2rad_sparse_csr 4582 tags: pointwise 4583 4584- func: deg2rad_(Tensor(a!) self) -> Tensor(a!) 4585 variants: function, method 4586 dispatch: 4587 CompositeExplicitAutograd: deg2rad_ 4588 SparseCPU, SparseCUDA: deg2rad_sparse_ 4589 SparseCsrCPU, SparseCsrCUDA: deg2rad_sparse_csr_ 4590 tags: pointwise 4591 4592- func: deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 4593 dispatch: 4594 CompositeExplicitAutograd: deg2rad_out 4595 SparseCPU, SparseCUDA: deg2rad_sparse_out 4596 SparseCsrCPU, SparseCsrCUDA: deg2rad_sparse_csr_out 4597 tags: pointwise 4598 4599- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4600 dispatch: 4601 CompositeExplicitAutograd: scalar_tensor 4602 autogen: scalar_tensor.out 4603 tags: core 4604 4605- func: rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4606 device_check: NoCheck 4607 device_guard: False 4608 dispatch: 4609 CompositeExplicitAutograd: rand 4610 autogen: rand.names_out 4611 tags: nondeterministic_seeded 4612 4613- func: rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4614 device_check: NoCheck 4615 device_guard: False 4616 tags: nondeterministic_seeded 4617 dispatch: 4618 CompositeExplicitAutograd: rand 4619 autogen: rand.generator_with_names_out 4620 4621- func: rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4622 tags: [core, nondeterministic_seeded] 4623 dispatch: 4624 CompositeExplicitAutograd: rand 4625 4626- func: rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4627 tags: nondeterministic_seeded 4628 dispatch: 4629 CompositeExplicitAutograd: rand 4630 4631- func: rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 4632 tags: nondeterministic_seeded 4633 dispatch: 4634 CompositeExplicitAutograd: rand_out 4635 4636- func: rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) 4637 tags: nondeterministic_seeded 4638 4639- func: rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor 4640 tags: nondeterministic_seeded 4641 dispatch: 4642 # NB: Although this composite mutates on the inside, it is 4643 # non-differentiable so NonFunctional doesn't apply 4644 CompositeExplicitAutograd: rand_like 4645 autogen: rand_like.out 4646 4647- func: randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4648 tags: nondeterministic_seeded 4649 dispatch: 4650 CompositeExplicitAutograd: randint 4651 4652- func: randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor 4653 tags: nondeterministic_seeded 4654 dispatch: 4655 CompositeExplicitAutograd: randint 4656 4657- func: randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4658 tags: nondeterministic_seeded 4659 dispatch: 4660 CompositeExplicitAutograd: randint 4661 4662- func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4663 tags: nondeterministic_seeded 4664 dispatch: 4665 CompositeExplicitAutograd: randint 4666 4667- func: randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 4668 tags: nondeterministic_seeded 4669 dispatch: 4670 CompositeExplicitAutograd: randint_out 4671 4672- func: randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) 4673 tags: nondeterministic_seeded 4674 dispatch: 4675 CompositeExplicitAutograd: randint_out 4676 4677- func: randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 4678 tags: nondeterministic_seeded 4679 dispatch: 4680 CompositeExplicitAutograd: randint_out 4681 4682- func: randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) 4683 tags: nondeterministic_seeded 4684 dispatch: 4685 CompositeExplicitAutograd: randint_out 4686 4687- func: randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor 4688 tags: nondeterministic_seeded 4689 dispatch: 4690 # NB: Although this composite mutates on the inside, it is 4691 # non-differentiable so NonFunctional doesn't apply 4692 CompositeExplicitAutograd: randint_like 4693 autogen: randint_like.out 4694 4695- func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor 4696 tags: nondeterministic_seeded 4697 dispatch: 4698 # NB: Although this composite mutates on the inside, it is 4699 # non-differentiable so NonFunctional doesn't apply 4700 CompositeExplicitAutograd: randint_like 4701 autogen: randint_like.low_dtype_out 4702 4703- func: randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4704 tags: [core, nondeterministic_seeded] 4705 dispatch: 4706 CompositeExplicitAutograd: randn 4707 4708- func: randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4709 tags: nondeterministic_seeded 4710 dispatch: 4711 CompositeExplicitAutograd: randn 4712 4713- func: randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4714 tags: nondeterministic_seeded 4715 device_check: NoCheck 4716 device_guard: False 4717 dispatch: 4718 CompositeExplicitAutograd: randn 4719 autogen: randn.names_out 4720 4721- func: randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor 4722 tags: nondeterministic_seeded 4723 device_check: NoCheck 4724 device_guard: False 4725 dispatch: 4726 CompositeExplicitAutograd: randn 4727 autogen: randn.generator_with_names_out 4728 4729- func: randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 4730 tags: nondeterministic_seeded 4731 4732- func: randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) 4733 tags: nondeterministic_seeded 4734 4735- func: randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor 4736 tags: nondeterministic_seeded 4737 dispatch: 4738 # NB: Although this composite mutates on the inside, it is 4739 # non-differentiable so NonFunctional doesn't apply 4740 CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like 4741 autogen: randn_like.out 4742 4743- func: randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4744 tags: [core, nondeterministic_seeded] 4745 dispatch: 4746 CompositeExplicitAutograd: randperm 4747 4748- func: randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4749 tags: nondeterministic_seeded 4750 dispatch: 4751 CompositeExplicitAutograd: randperm 4752 4753- func: randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) 4754 tags: nondeterministic_seeded 4755 dispatch: 4756 CompositeExplicitAutograd: randperm_out 4757 4758- func: randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) 4759 tags: nondeterministic_seeded 4760 dispatch: 4761 CPU: randperm_out_cpu 4762 CUDA: randperm_out_cuda 4763 MPS: randperm_out_mps 4764 4765- func: range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4766 dispatch: 4767 CompositeExplicitAutograd: range 4768 4769- func: range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 4770 dispatch: 4771 CompositeExplicitAutograd: range 4772 4773- func: range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!) 4774 dispatch: 4775 CompositeExplicitAutograd: range_out_no_step 4776 4777- func: range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) 4778 dispatch: 4779 CPU, Meta: range_out 4780 CUDA: range_cuda_out 4781 MPS: range_mps_out 4782 cpp_no_default_args: ['step'] 4783 4784- func: ravel(Tensor(a) self) -> Tensor(a) 4785 variants: function, method 4786 4787- func: reciprocal(Tensor self) -> Tensor 4788 device_check: NoCheck # TensorIterator 4789 structured_delegate: reciprocal.out 4790 variants: function, method 4791 tags: [core, pointwise] 4792 4793- func: reciprocal_(Tensor(a!) self) -> Tensor(a!) 4794 device_check: NoCheck # TensorIterator 4795 structured_delegate: reciprocal.out 4796 variants: function, method 4797 tags: pointwise 4798 4799- func: reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
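# Illustrative usage sketch (not part of the schema): reciprocal is a structured
# pointwise op computing 1 / x elementwise (integral inputs are promoted to a
# floating dtype), and the out= overload declared here reuses a preallocated
# result tensor, e.g.:
#
#   x = torch.tensor([1.0, 2.0, 4.0])
#   torch.reciprocal(x)                      # tensor([1.0000, 0.5000, 0.2500])
#   torch.reciprocal(x, out=torch.empty(3))  # same, into an existing buffer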
4800 device_check: NoCheck # TensorIterator 4801 structured: True 4802 structured_inherits: TensorIteratorBase 4803 dispatch: 4804 CPU, CUDA: reciprocal_out 4805 MPS: reciprocal_out_mps 4806 tags: pointwise 4807 4808- func: neg(Tensor self) -> Tensor 4809 device_check: NoCheck # TensorIterator 4810 structured_delegate: neg.out 4811 variants: function, method 4812 dispatch: 4813 SparseCPU, SparseCUDA: neg_sparse 4814 SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr 4815 NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg 4816 tags: [core, pointwise] 4817 4818- func: neg_(Tensor(a!) self) -> Tensor(a!) 4819 device_check: NoCheck # TensorIterator 4820 structured_delegate: neg.out 4821 variants: function, method 4822 dispatch: 4823 SparseCPU, SparseCUDA: neg_sparse_ 4824 SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr_ 4825 NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg_ 4826 tags: pointwise 4827 4828- func: neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 4829 device_check: NoCheck # TensorIterator 4830 structured: True 4831 structured_inherits: TensorIteratorBase 4832 dispatch: 4833 CPU, CUDA: neg_out 4834 MPS: neg_out_mps 4835 SparseCPU, SparseCUDA: neg_out_sparse 4836 SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr_out 4837 tags: pointwise 4838# Alias for neg 4839 4840- func: negative(Tensor self) -> Tensor 4841 variants: function, method 4842 4843- func: negative_(Tensor(a!) self) -> Tensor(a!) 4844 variants: function, method 4845 4846- func: negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 4847 4848- func: repeat(Tensor self, SymInt[] repeats) -> Tensor 4849 variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. 4850 dispatch: 4851 CompositeExplicitAutograd: repeat 4852 MPS: repeat_mps 4853 autogen: repeat.out 4854 tags: core 4855 4856- func: repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor 4857 variants: function 4858 dispatch: 4859 CPU: repeat_interleave_cpu 4860 CUDA: repeat_interleave_cuda 4861 MPS: repeat_interleave_mps 4862 tags: dynamic_output_shape 4863 autogen: repeat_interleave.Tensor_out 4864 4865- func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor 4866 variants: function, method 4867 dispatch: 4868 CompositeImplicitAutograd: repeat_interleave_symint 4869 4870- func: repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor 4871 variants: function, method 4872 dispatch: 4873 CompositeImplicitAutograd: repeat_interleave_symint 4874 4875- func: reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) 4876 variants: function, method 4877 device_check: NoCheck 4878 device_guard: False 4879 dispatch: 4880 CompositeImplicitAutograd: reshape_symint 4881 CompositeImplicitAutogradNestedTensor: reshape_nested_symint 4882 4883- func: _reshape_copy(Tensor self, SymInt[] size) -> Tensor 4884 variants: function 4885 dispatch: 4886 CompositeExplicitAutograd: _reshape_copy_symint 4887 4888# NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape. 4889# They are not user-facing, hence the leading underscore. Please don't use it 4890# anywhere else. 
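# Illustrative usage sketch (not part of the schema): the public reshape above
# returns a view when the requested shape is compatible with the input's
# strides and otherwise materializes a copy, which is why internal helpers
# such as _reshape_alias below exist, e.g.:
#
#   x = torch.arange(6)
#   y = x.reshape(2, 3)                    # contiguous input: typically a view
#   z = torch.ones(2, 3).t().reshape(6)    # non-contiguous input: may be a copy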
4891- func: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) 4892 variants: function, method 4893 device_check: NoCheck 4894 device_guard: False 4895 dispatch: 4896 CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS: _reshape_alias 4897 # We don't need to support mkldnn since this is handled explicitly by the reshape operator. 4898 4899- func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor 4900 device_check: NoCheck 4901 device_guard: False 4902 dispatch: 4903 MkldnnCPU: mkldnn_reshape 4904 autogen: _mkldnn_reshape.out 4905 4906- func: reshape_as(Tensor(a) self, Tensor other) -> Tensor(a) 4907 variants: method 4908 device_check: NoCheck 4909 device_guard: False 4910 dispatch: 4911 CompositeImplicitAutograd: reshape_as 4912 CompositeImplicitAutogradNestedTensor: reshape_as_nested 4913 4914- func: round(Tensor self) -> Tensor 4915 device_check: NoCheck # TensorIterator 4916 structured_delegate: round.out 4917 variants: function, method 4918 dispatch: 4919 SparseCPU, SparseCUDA: round_sparse 4920 SparseCsrCPU, SparseCsrCUDA: round_sparse_csr 4921 tags: [core, pointwise] 4922 4923- func: round_(Tensor(a!) self) -> Tensor(a!) 4924 device_check: NoCheck # TensorIterator 4925 structured_delegate: round.out 4926 variants: function, method 4927 dispatch: 4928 SparseCPU, SparseCUDA: round_sparse_ 4929 SparseCsrCPU, SparseCsrCUDA: round_sparse_csr_ 4930 tags: pointwise 4931 4932- func: round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 4933 device_check: NoCheck # TensorIterator 4934 structured: True 4935 structured_inherits: TensorIteratorBase 4936 dispatch: 4937 CPU: round_out 4938 CUDA: round_out 4939 MPS: round_out_mps 4940 SparseCPU, SparseCUDA: round_sparse_out 4941 SparseCsrCPU, SparseCsrCUDA: round_sparse_csr_out 4942 tags: pointwise 4943 4944- func: round.decimals(Tensor self, *, int decimals) -> Tensor 4945 device_check: NoCheck # TensorIterator 4946 structured_delegate: round.decimals_out 4947 variants: function, method 4948 tags: pointwise 4949 4950- func: round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!) 4951 device_check: NoCheck # TensorIterator 4952 structured_delegate: round.decimals_out 4953 variants: function, method 4954 tags: pointwise 4955 4956- func: round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!) 4957 device_check: NoCheck # TensorIterator 4958 structured: True 4959 structured_inherits: TensorIteratorBase 4960 dispatch: 4961 CPU: round_decimals_out 4962 CUDA: round_decimals_out 4963 tags: pointwise 4964 4965- func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor 4966 device_check: NoCheck # TensorIterator 4967 tags: nondeterministic_seeded 4968 4969- func: rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) 4970 tags: nondeterministic_seeded 4971 device_check: NoCheck # TensorIterator 4972 4973- func: relu(Tensor self) -> Tensor 4974 device_check: NoCheck # TensorIterator 4975 variants: function, method 4976 dispatch: 4977 CPU, CUDA: relu 4978 MPS: relu_mps 4979 MkldnnCPU: mkldnn_relu 4980 QuantizedCPU: relu_quantized_cpu 4981 QuantizedCUDA: relu_quantized_cuda 4982 NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu 4983 SparseCPU, SparseCUDA: relu_sparse 4984 SparseCsrCPU, SparseCsrCUDA: relu_sparse_csr 4985 tags: [core, pointwise] 4986 4987- func: relu_(Tensor(a!) self) -> Tensor(a!) 
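# Illustrative usage sketch (not part of the schema): relu clamps negative
# values to zero; the in-place variant declared here mutates `self`, e.g.:
#
#   x = torch.tensor([-1.0, 0.5, 2.0])
#   torch.relu(x)   # tensor([0.0000, 0.5000, 2.0000]), new tensor
#   x.relu_()       # same result, written into x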
4988 device_check: NoCheck # TensorIterator 4989 variants: function, method 4990 dispatch: 4991 CPU, CUDA: relu_ 4992 MPS: relu_mps_ 4993 MkldnnCPU: mkldnn_relu_ 4994 QuantizedCPU: relu_quantized_cpu_ 4995 QuantizedCUDA: relu_quantized_cuda_ 4996 NestedTensorCPU, NestedTensorCUDA: NestedTensor_relu_ 4997 SparseCPU, SparseCUDA: relu_sparse_ 4998 SparseCsrCPU, SparseCsrCUDA: relu_sparse_csr_ 4999 autogen: relu.out 5000 tags: pointwise 5001 5002- func: relu6(Tensor self) -> Tensor 5003 python_module: nn 5004 5005- func: relu6_(Tensor(a!) self) -> Tensor(a!) 5006 python_module: nn 5007 5008- func: prelu(Tensor self, Tensor weight) -> Tensor 5009 variants: function, method 5010 autogen: prelu.out 5011 5012- func: _prelu_kernel(Tensor self, Tensor weight) -> Tensor 5013 dispatch: 5014 CPU, CUDA: _prelu_kernel 5015 QuantizedCPU: _prelu_kernel_quantized_cpu 5016 MkldnnCPU: mkldnn_prelu 5017 MPS: prelu_mps 5018 5019- func: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) 5020 dispatch: 5021 CPU, CUDA: _prelu_kernel_backward 5022 MkldnnCPU: mkldnn_prelu_backward 5023 MPS: prelu_backward_mps 5024 5025- func: gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!) 5026 structured: True 5027 structured_inherits: TensorIteratorBase 5028 device_check: NoCheck # TensorIterator 5029 python_module: nn 5030 dispatch: 5031 CPU: gelu_out_cpu 5032 CUDA: gelu_out_cuda 5033 MPS: gelu_out_mps 5034 5035- func: gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!) 5036 structured_delegate: gelu.out 5037 device_check: NoCheck # TensorIterator 5038 python_module: nn 5039 dispatch: 5040 QuantizedCPU: gelu_quantized_cpu_ 5041 NestedTensorCPU, NestedTensorCUDA: NestedTensor_gelu_ 5042 5043- func: gelu(Tensor self, *, str approximate='none') -> Tensor 5044 structured_delegate: gelu.out 5045 device_check: NoCheck # TensorIterator 5046 python_module: nn 5047 dispatch: 5048 MkldnnCPU: mkldnn_gelu 5049 QuantizedCPU: gelu_quantized_cpu 5050 QuantizedCUDA: gelu_quantized_cuda 5051 NestedTensorCPU, NestedTensorCUDA: NestedTensor_gelu 5052 tags: [core, pointwise] 5053 5054- func: gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!) 5055 structured: True 5056 structured_inherits: TensorIteratorBase 5057 python_module: nn 5058 dispatch: 5059 CPU: gelu_backward_out_cpu 5060 CUDA: gelu_backward_out_cuda 5061 MPS: gelu_backward_out_mps 5062 5063- func: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor 5064 structured_delegate: gelu_backward.grad_input 5065 python_module: nn 5066 dispatch: 5067 MkldnnCPU: mkldnn_gelu_backward 5068 NestedTensorCPU, NestedTensorCUDA: gelu_backwards_nested 5069 tags: pointwise 5070 5071- func: infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor 5072 variants: function 5073 python_module: nn 5074 device_check: NoCheck 5075 device_guard: False 5076 5077- func: hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) 5078 structured: True 5079 structured_inherits: TensorIteratorBase 5080 device_check: NoCheck # TensorIterator 5081 dispatch: 5082 CPU, CUDA: hardshrink_out 5083 5084- func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor 5085 structured_delegate: hardshrink.out 5086 device_check: NoCheck # TensorIterator 5087 variants: function, method 5088 5089- func: hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) 
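# Illustrative semantics sketch (not part of the schema): hardshrink zeroes
# elements with |x| <= lambd and keeps the rest, so the backward declared here
# effectively masks the incoming gradient with |self| > lambd, e.g.:
#
#   x = torch.tensor([-1.0, 0.2, 0.7])
#   torch.nn.functional.hardshrink(x, lambd=0.5)   # tensor([-1.0, 0.0, 0.7])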
5090 structured: True 5091 structured_inherits: TensorIteratorBase 5092 dispatch: 5093 CPU, CUDA: hardshrink_backward_out 5094 5095- func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor 5096 structured_delegate: hardshrink_backward.grad_input 5097 variants: function, method 5098 5099- func: rsqrt(Tensor self) -> Tensor 5100 device_check: NoCheck # TensorIterator 5101 structured_delegate: rsqrt.out 5102 variants: function, method 5103 tags: [core, pointwise] 5104 5105- func: rsqrt_(Tensor(a!) self) -> Tensor(a!) 5106 device_check: NoCheck # TensorIterator 5107 structured_delegate: rsqrt.out 5108 variants: function, method 5109 tags: pointwise 5110 5111- func: rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 5112 device_check: NoCheck # TensorIterator 5113 structured: True 5114 structured_inherits: TensorIteratorBase 5115 dispatch: 5116 CPU, CUDA: rsqrt_out 5117 MPS: rsqrt_out_mps 5118 tags: pointwise 5119 5120- func: select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a) 5121 variants: function, method 5122 device_check: NoCheck 5123 device_guard: False 5124 5125- func: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) 5126 variants: function, method 5127 device_check: NoCheck 5128 device_guard: False 5129 dispatch: 5130 CompositeExplicitAutograd: select_symint 5131 SparseCsrCPU, SparseCsrCUDA: select_sparse_csr 5132 NestedTensorCPU, NestedTensorCUDA: select_nested 5133 tags: core 5134 5135- func: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor 5136 variants: function 5137 device_check: NoCheck 5138 device_guard: False 5139 dispatch: 5140 CompositeExplicitAutogradNonFunctional: select_backward_symint 5141 autogen: select_backward.out 5142 5143- func: _nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor 5144 variants: function 5145 device_check: NoCheck 5146 device_guard: False 5147 dispatch: 5148 NestedTensorCPU, NestedTensorCUDA: _nested_select_backward_symint 5149 5150- func: selu(Tensor self) -> Tensor 5151 device_check: NoCheck # TensorIterator 5152 5153- func: selu_(Tensor(a!) self) -> Tensor(a!) 5154 device_check: NoCheck # TensorIterator 5155 5156- func: celu(Tensor self, Scalar alpha=1.0) -> Tensor 5157 device_check: NoCheck # TensorIterator 5158 dispatch: 5159 CompositeExplicitAutograd: celu 5160 5161- func: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) 5162 device_check: NoCheck # TensorIterator 5163 dispatch: 5164 CompositeExplicitAutograd: celu_ 5165 autogen: celu.out 5166 5167- func: silu(Tensor self) -> Tensor 5168 structured_delegate: silu.out 5169 python_module: nn 5170 dispatch: 5171 NestedTensorCPU, NestedTensorCUDA: NestedTensor_silu 5172 tags: pointwise 5173 5174- func: silu_(Tensor(a!) self) -> Tensor(a!) 5175 structured_delegate: silu.out 5176 python_module: nn 5177 dispatch: 5178 NestedTensorCPU, NestedTensorCUDA: NestedTensor_silu_ 5179 tags: pointwise 5180 5181- func: silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 5182 structured: True 5183 structured_inherits: TensorIteratorBase 5184 python_module: nn 5185 dispatch: 5186 CPU, CUDA: silu_out 5187 MPS: silu_out_mps 5188 tags: pointwise 5189 5190- func: silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) 
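# Illustrative math sketch (not part of the schema): silu(x) = x * sigmoid(x),
# so the backward entries here compute, per element,
#   d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
# which is what the composite math fallback expresses, e.g.:
#
#   x = torch.tensor([1.0], requires_grad=True)
#   torch.nn.functional.silu(x).sum().backward()
#   # x.grad should equal sigmoid(1) * (1 + 1 * (1 - sigmoid(1)))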
5191 structured: True 5192 structured_inherits: TensorIteratorBase 5193 python_module: nn 5194 dispatch: 5195 CPU, CUDA: silu_backward_out 5196 MPS: silu_backward_out_mps 5197 tags: pointwise 5198 5199- func: silu_backward(Tensor grad_output, Tensor self) -> Tensor 5200 structured_delegate: silu_backward.grad_input 5201 python_module: nn 5202 dispatch: 5203 CompositeImplicitAutograd: math_silu_backward 5204 NestedTensorCPU, NestedTensorCUDA: silu_backward_nested 5205 tags: pointwise 5206 5207- func: mish(Tensor self) -> Tensor 5208 structured_delegate: mish.out 5209 python_module: nn 5210 5211- func: mish_(Tensor(a!) self) -> Tensor(a!) 5212 structured_delegate: mish.out 5213 python_module: nn 5214 5215- func: mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 5216 structured: True 5217 structured_inherits: TensorIteratorBase 5218 python_module: nn 5219 dispatch: 5220 CPU, CUDA: mish_out 5221 MPS: mish_out_mps 5222 5223- func: mish_backward(Tensor grad_output, Tensor self) -> Tensor 5224 python_module: nn 5225 dispatch: 5226 CPU, CUDA: mish_backward 5227 MPS: mish_backward_mps 5228 CompositeImplicitAutograd: math_mish_backward 5229 5230- func: sigmoid(Tensor self) -> Tensor 5231 device_check: NoCheck # TensorIterator 5232 structured_delegate: sigmoid.out 5233 variants: function, method 5234 dispatch: 5235 QuantizedCPU: sigmoid_quantized_cpu 5236 MkldnnCPU: mkldnn_sigmoid 5237 tags: [core, pointwise] 5238 5239- func: sigmoid_(Tensor(a!) self) -> Tensor(a!) 5240 device_check: NoCheck # TensorIterator 5241 structured_delegate: sigmoid.out 5242 variants: function, method 5243 dispatch: 5244 MkldnnCPU: mkldnn_sigmoid_ 5245 tags: pointwise 5246 5247- func: sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 5248 device_check: NoCheck # TensorIterator 5249 structured: True 5250 structured_inherits: TensorIteratorBase 5251 dispatch: 5252 CPU, CUDA: sigmoid_out 5253 MPS: sigmoid_out_mps 5254 tags: pointwise 5255 5256- func: logit(Tensor self, float? eps=None) -> Tensor 5257 variants: function, method 5258 dispatch: 5259 CPU, CUDA: logit 5260 MPS: logit_mps 5261 tags: pointwise 5262 5263- func: logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!) 5264 variants: function, method 5265 dispatch: 5266 CPU, CUDA: logit_ 5267 tags: pointwise 5268 5269- func: logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) 5270 dispatch: 5271 CPU, CUDA: logit_out 5272 MPS: logit_out_mps 5273 tags: pointwise 5274 5275- func: sin(Tensor self) -> Tensor 5276 device_check: NoCheck # TensorIterator 5277 structured_delegate: sin.out 5278 variants: function, method 5279 dispatch: 5280 SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr 5281 SparseCPU, SparseCUDA: sin_sparse 5282 NestedTensorCPU, NestedTensorCUDA: sin_nested 5283 tags: [core, pointwise] 5284 5285- func: sin_(Tensor(a!) self) -> Tensor(a!) 5286 device_check: NoCheck # TensorIterator 5287 structured_delegate: sin.out 5288 variants: function, method 5289 dispatch: 5290 SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr_ 5291 SparseCPU, SparseCUDA: sin_sparse_ 5292 tags: pointwise 5293 5294- func: sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
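# Illustrative usage sketch (not part of the schema): sin is a structured
# pointwise op with sparse kernels and an out= overload declared here, e.g.:
#
#   x = torch.tensor([0.0, 1.5707963267948966])
#   torch.sin(x)                        # approximately tensor([0., 1.])
#   torch.sin(x, out=torch.empty(2))    # same, into a preallocated buffer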
5295 device_check: NoCheck # TensorIterator 5296 structured: True 5297 structured_inherits: TensorIteratorBase 5298 dispatch: 5299 CPU, CUDA: sin_out 5300 MPS: sin_out_mps 5301 SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr_out 5302 SparseCPU, SparseCUDA: sin_sparse_out 5303 tags: pointwise 5304 5305- func: sinc(Tensor self) -> Tensor 5306 structured_delegate: sinc.out 5307 variants: function, method 5308 tags: pointwise 5309 5310- func: sinc_(Tensor(a!) self) -> Tensor(a!) 5311 structured_delegate: sinc.out 5312 variants: function, method 5313 tags: pointwise 5314 5315- func: sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 5316 structured: True 5317 structured_inherits: TensorIteratorBase 5318 dispatch: 5319 CPU, CUDA: sinc_out 5320 tags: pointwise 5321 5322- func: sinh(Tensor self) -> Tensor 5323 device_check: NoCheck # TensorIterator 5324 structured_delegate: sinh.out 5325 variants: function, method 5326 dispatch: 5327 SparseCPU, SparseCUDA: sinh_sparse 5328 SparseCsrCPU, SparseCsrCUDA: sinh_sparse_csr 5329 tags: [core, pointwise] 5330 5331- func: sinh_(Tensor(a!) self) -> Tensor(a!) 5332 device_check: NoCheck # TensorIterator 5333 structured_delegate: sinh.out 5334 variants: function, method 5335 dispatch: 5336 SparseCPU, SparseCUDA: sinh_sparse_ 5337 SparseCsrCPU, SparseCsrCUDA: sinh_sparse_csr_ 5338 tags: pointwise 5339 5340- func: sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 5341 device_check: NoCheck # TensorIterator 5342 structured: True 5343 structured_inherits: TensorIteratorBase 5344 dispatch: 5345 CPU, CUDA: sinh_out 5346 MPS: sinh_out_mps 5347 SparseCPU, SparseCUDA: sinh_sparse_out 5348 SparseCsrCPU, SparseCsrCUDA: sinh_sparse_csr_out 5349 tags: pointwise 5350 5351# Returns a new `Variable`, detached from its autograd graph, that shares the 5352# same storage as this one. This method is OK to call if the `Variable` is a view. 5353# 5354# NOTE: Previously, if we changed the tensor metadata (e.g. sizes / strides / 5355# storage / storage_offset) of a tensor created from `detach()`, that metadata 5356# was also updated in the original tensor. The new behavior is that metadata 5357# changes to the detached tensor no longer update the original tensor. In 5358# `detach()` we therefore set `allow_tensor_metadata_change_` to false, making 5359# such changes explicitly illegal, so users cannot change metadata of the 5360# detached tensor and expect the original tensor to be updated as well. 5361 5362- func: detach(Tensor(a) self) -> Tensor(a) 5363 variants: function, method 5364 dispatch: 5365 CompositeExplicitAutograd: detach 5366 NestedTensorCPU, NestedTensorCUDA: detach 5367 5368# Like `detach()`, but modifies this `Variable` in-place. This method may 5369# only be called on non-view `Variable`s. You can use `is_view()` to check 5370# this. If this `Variable` is a view, it throws a `std::runtime_error()`. 5371- func: detach_(Tensor(a!) self) -> Tensor(a!)
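# Illustrative usage sketch (not part of the schema): detach returns a tensor
# that shares storage with `self` but is cut out of the autograd graph, while
# detach_ declared here detaches `self` in place, e.g.:
#
#   x = torch.ones(3, requires_grad=True)
#   y = x.detach()          # y.requires_grad is False; y shares storage with x
#   w = (x * 2).detach_()   # detaches the intermediate result in place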
5372 variants: function, method 5373 tags: inplace_view 5374 dispatch: 5375 CompositeExplicitAutograd: detach_ 5376 5377- func: size.int(Tensor self, int dim) -> int 5378 variants: function 5379 device_check: NoCheck 5380 device_guard: False 5381 manual_cpp_binding: True 5382 5383- func: size.Dimname(Tensor self, Dimname dim) -> int 5384 variants: function, method 5385 device_check: NoCheck 5386 device_guard: False 5387 5388- func: sym_size.int(Tensor self, int dim) -> SymInt 5389 variants: function 5390 device_check: NoCheck 5391 device_guard: False 5392 tags: core 5393 manual_cpp_binding: True 5394 5395- func: sym_numel(Tensor self) -> SymInt 5396 variants: function 5397 device_check: NoCheck 5398 device_guard: False 5399 tags: core 5400 manual_cpp_binding: True 5401 5402- func: sym_storage_offset(Tensor self) -> SymInt 5403 variants: function 5404 device_check: NoCheck 5405 device_guard: False 5406 tags: core 5407 manual_cpp_binding: True 5408 5409- func: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) 5410 variants: function, method 5411 device_check: NoCheck 5412 device_guard: False 5413 dispatch: 5414 CompositeExplicitAutograd: slice 5415 tags: core 5416 5417# NOTE: The implementation of split_with_sizes bypasses the dispatcher to call this; undo 5418# that if adding specific implementations here! 5419 5420- func: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor 5421 variants: function 5422 device_check: NoCheck 5423 device_guard: False 5424 dispatch: 5425 CompositeExplicitAutograd: slice_backward 5426 autogen: slice_backward.out 5427 5428# NB: This op exists to back the implementation of reverse view_funcs for various views (chunk, 5429# slice.Tensor, split_with_sizes, et al.). Currently, these are only used during fake-ification 5430# of PT2 graph input subclass instances that are views. This means: 5431# * This op shouldn't really show up in eager mode (so e.g. XLA shouldn't have to implement it) 5432# * This op shouldn't show up in a PT2 graph (so a PT2 backend shouldn't have to implement it) 5433# * A subclass will have to implement this to work in PT2 if a subclass view is used as a graph 5434# input AND the view utilizes this op in its inverse. The idea is that slice_inverse() is 5435# easier to implement for a subclass than as_strided() 5436- func: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) 5437 variants: function, method 5438 device_check: NoCheck 5439 device_guard: False 5440 dispatch: 5441 CompositeExplicitAutograd: slice_inverse_symint 5442 5443- func: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? 
end=None, SymInt step=1) -> Tensor 5444 variants: function, method 5445 device_check: NoCheck 5446 device_guard: False 5447 dispatch: 5448 CompositeExplicitAutogradNonFunctional: slice_scatter 5449 autogen: slice_scatter.out 5450 tags: [core, view_copy] 5451 5452- func: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor 5453 variants: function, method 5454 device_check: NoCheck 5455 device_guard: False 5456 dispatch: 5457 CompositeExplicitAutogradNonFunctional: select_scatter_symint 5458 autogen: select_scatter.out 5459 tags: core 5460 5461- func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor 5462 variants: function, method 5463 device_check: NoCheck 5464 device_guard: False 5465 dispatch: 5466 CompositeExplicitAutogradNonFunctional: diagonal_scatter 5467 autogen: diagonal_scatter.out 5468 5469- func: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor 5470 variants: function, method 5471 device_check: NoCheck 5472 device_guard: False 5473 dispatch: 5474 CompositeExplicitAutogradNonFunctional: as_strided_scatter_symint 5475 autogen: as_strided_scatter.out 5476 5477- func: smm(Tensor self, Tensor mat2) -> Tensor 5478 variants: function, method 5479 5480# softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models. 5481- func: softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor 5482 variants: function, method 5483 5484- func: softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) 5485 variants: function 5486 dispatch: 5487 CompositeExplicitAutograd: softmax_out 5488 5489- func: softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor 5490 variants: function, method 5491 5492- func: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor 5493 structured_delegate: _softmax.out 5494 dispatch: 5495 MkldnnCPU: mkldnn_softmax 5496 NestedTensorCPU, NestedTensorCUDA: softmax_nested 5497 tags: core 5498 5499- func: _softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) 5500 structured: True 5501 dispatch: 5502 CPU: softmax_cpu_out 5503 CUDA: softmax_cuda_out 5504 MPS: softmax_mps_out 5505 5506- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor 5507 structured_delegate: _softmax_backward_data.out 5508 dispatch: 5509 NestedTensorCPU, NestedTensorCUDA: nested_softmax_backward 5510 5511- func: _softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!) 
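# Illustrative usage sketch (not part of the schema): the public softmax above
# normalizes along `dim` so each slice sums to one; _softmax and
# _softmax_backward_data are the internal kernels it lowers to, e.g.:
#
#   x = torch.randn(2, 3)
#   p = torch.softmax(x, dim=-1)
#   p.sum(dim=-1)   # approximately tensor([1., 1.])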
5512 structured: True 5513 dispatch: 5514 CPU: softmax_backward_cpu_out 5515 CUDA: softmax_backward_cuda_out 5516 MPS: softmax_backward_mps_out 5517 5518- func: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] 5519 variants: function, method 5520 device_check: NoCheck 5521 device_guard: False 5522 dispatch: 5523 CompositeExplicitAutograd: unsafe_split 5524 autogen: unsafe_split.Tensor_out 5525 5526- func: split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] 5527 variants: function, method 5528 device_check: NoCheck 5529 device_guard: False 5530 dispatch: 5531 CompositeExplicitAutograd: split 5532 5533- func: split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[] 5534 variants: function, method 5535 device_guard: False 5536 dispatch: 5537 CompositeImplicitAutograd: split_symint 5538 5539- func: unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] 5540 variants: function, method 5541 device_check: NoCheck 5542 device_guard: False 5543 dispatch: 5544 CompositeExplicitAutograd: unsafe_split_with_sizes 5545 autogen: unsafe_split_with_sizes.out 5546 5547- func: split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] 5548 variants: function, method 5549 device_check: NoCheck 5550 device_guard: False 5551 dispatch: 5552 CompositeExplicitAutograd: split_with_sizes 5553 NestedTensorCPU, NestedTensorCUDA: split_with_sizes_nested 5554 tags: core 5555 5556- func: hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] 5557 variants: function, method 5558 5559- func: hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] 5560 variants: function, method 5561 5562- func: vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] 5563 variants: function, method 5564 5565- func: vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] 5566 variants: function, method 5567 5568- func: dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] 5569 variants: function, method 5570 5571- func: dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] 5572 variants: function, method 5573 5574- func: squeeze(Tensor(a) self) -> Tensor(a) 5575 variants: function, method 5576 device_check: NoCheck 5577 device_guard: False 5578 dispatch: 5579 CompositeExplicitAutograd: squeeze 5580 QuantizedCPU, QuantizedCUDA: squeeze_quantized 5581 NestedTensorCPU, NestedTensorCUDA: squeeze_nested 5582 5583- func: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) 5584 variants: function, method 5585 device_check: NoCheck 5586 device_guard: False 5587 dispatch: 5588 CompositeExplicitAutograd: squeeze 5589 QuantizedCPU, QuantizedCUDA: squeeze_quantized 5590 NestedTensorCPU, NestedTensorCUDA: squeeze_dim_nested 5591 tags: core 5592 5593- func: squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a) 5594 variants: function, method 5595 device_check: NoCheck 5596 device_guard: False 5597 5598 5599- func: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) 5600 variants: function, method 5601 device_check: NoCheck 5602 device_guard: False 5603 dispatch: 5604 CompositeExplicitAutograd: squeeze 5605 QuantizedCPU, QuantizedCUDA: squeeze_quantized 5606 NestedTensorCPU, NestedTensorCUDA: squeeze_dim_nested 5607 tags: core 5608 5609- func: squeeze_(Tensor(a!) self) -> Tensor(a!) 5610 variants: method 5611 device_check: NoCheck 5612 device_guard: False 5613 tags: inplace_view 5614 dispatch: 5615 CompositeExplicitAutograd: squeeze_ 5616 5617- func: squeeze_.dim(Tensor(a!) 
self, int dim) -> Tensor(a!) 5618 variants: method 5619 device_check: NoCheck 5620 device_guard: False 5621 tags: inplace_view 5622 dispatch: 5623 CompositeExplicitAutograd: squeeze_ 5624 5625- func: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!) 5626 variants: method 5627 device_check: NoCheck 5628 device_guard: False 5629 tags: inplace_view 5630 dispatch: 5631 CompositeExplicitAutograd: squeeze_ 5632 5633- func: squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!) 5634 variants: method 5635 device_check: NoCheck 5636 device_guard: False 5637 tags: inplace_view 5638 5639- func: sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor 5640 variants: function, method 5641 5642- func: sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 5643 dispatch: 5644 CPU: _sspaddmm_out_only_sparse 5645 CUDA: _sspaddmm_out_only_sparse_cuda 5646 SparseCPU: _sspaddmm_out_cpu 5647 SparseCUDA: _sspaddmm_out_cuda 5648 5649- func: _chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor 5650 dispatch: 5651 CompositeExplicitAutograd: _chunk_cat 5652 CUDA: _chunk_cat_cuda 5653 5654- func: _chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!) 5655 dispatch: 5656 CompositeExplicitAutograd: _chunk_cat_out 5657 CUDA: _chunk_cat_out_cuda 5658 5659- func: stack(Tensor[] tensors, int dim=0) -> Tensor 5660 dispatch: 5661 CompositeExplicitAutograd: stack 5662 5663- func: stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) 5664 dispatch: 5665 CompositeExplicitAutograd: stack_out 5666 5667- func: _stack(Tensor[] tensors, int dim=0) -> Tensor 5668 dispatch: # match the backends supported by _cat 5669 CPU: _stack_cpu 5670 CompositeExplicitAutograd: _stack 5671 5672- func: _stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) 5673 dispatch: # match the backends supported by _cat_out 5674 CPU: _stack_out_cpu 5675 CompositeExplicitAutograd: _stack_out 5676 5677- func: hstack(Tensor[] tensors) -> Tensor 5678 5679- func: hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) 5680 5681- func: vstack(Tensor[] tensors) -> Tensor 5682 5683- func: vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) 5684 5685- func: dstack(Tensor[] tensors) -> Tensor 5686 5687- func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) 5688 5689# Overload without center & pad mode, needed for forward-compatibility 5690- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor 5691 variants: function, method 5692 cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized'] 5693 5694- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor 5695 variants: function, method 5696 5697- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? 
length=None, bool return_complex=False) -> Tensor 5698 variants: function, method 5699 5700- func: stride.int(Tensor self, int dim) -> int 5701 variants: function 5702 device_check: NoCheck 5703 device_guard: False 5704 manual_cpp_binding: True 5705 5706- func: stride.Dimname(Tensor self, Dimname dim) -> int 5707 variants: function, method 5708 device_check: NoCheck 5709 device_guard: False 5710 5711- func: sym_stride.int(Tensor self, int dim) -> SymInt 5712 variants: function 5713 device_check: NoCheck 5714 device_guard: False 5715 tags: core 5716 manual_cpp_binding: True 5717 5718- func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor 5719 device_check: NoCheck # TensorIterator 5720 variants: function, method 5721 dispatch: 5722 CompositeExplicitAutograd: sum 5723 SparseCPU, SparseCUDA, SparseMeta: sum_coo 5724 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr 5725 autogen: sum.out 5726 5727- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor 5728 # TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype 5729 structured_delegate: sum.IntList_out 5730 device_check: NoCheck # TensorIterator 5731 variants: function, method 5732 dispatch: 5733 NestedTensorCPU: NestedTensor_sum_dim_CPU 5734 SparseCPU, SparseCUDA: sum_sparse_coo 5735 SparseCsrCPU, SparseCsrCUDA: sum_sparse_compressed 5736 tags: core 5737 5738- func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor 5739 device_check: NoCheck # TensorIterator 5740 variants: function, method 5741 5742- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 5743 structured: True 5744 device_check: NoCheck # TensorIterator 5745 dispatch: 5746 CPU, CUDA: sum_out 5747 MPS: sum_out_mps 5748 5749- func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 5750 device_check: NoCheck # TensorIterator 5751 5752# TODO: this function will be replaced once nested expand semantics have been settled on 5753- func: _nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor 5754 dispatch: 5755 NestedTensorCPU: _nested_sum_backward_cpu 5756 5757- func: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor 5758 variants: function, method 5759 dispatch: 5760 CPU, CUDA: nansum 5761 MPS: nansum_mps 5762 5763- func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 5764 dispatch: 5765 CPU, CUDA: nansum_out 5766 MPS: nansum_out_mps 5767 5768- func: sum_to_size(Tensor self, SymInt[] size) -> Tensor 5769 variants: method 5770 device_check: NoCheck 5771 device_guard: False 5772 dispatch: 5773 CompositeImplicitAutograd: sum_to_size_symint 5774 5775- func: sqrt(Tensor self) -> Tensor 5776 device_check: NoCheck # TensorIterator 5777 structured_delegate: sqrt.out 5778 variants: function, method 5779 dispatch: 5780 SparseCPU, SparseCUDA: sqrt_sparse 5781 SparseCsrCPU, SparseCsrCUDA: sqrt_sparse_csr 5782 tags: [core, pointwise] 5783 5784- func: sqrt_(Tensor(a!) self) -> Tensor(a!) 
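# Illustrative usage sketch (not part of the schema): sqrt is pointwise with
# sparse kernels; the in-place variant declared here mutates `self`, and
# negative floating-point inputs produce NaN rather than raising, e.g.:
#
#   x = torch.tensor([4.0, -1.0])
#   torch.sqrt(x)   # tensor([2., nan])
#   x.sqrt_()       # same values written into x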
5785 device_check: NoCheck # TensorIterator 5786 structured_delegate: sqrt.out 5787 variants: function, method 5788 dispatch: 5789 SparseCPU, SparseCUDA: sqrt_sparse_ 5790 SparseCsrCPU, SparseCsrCUDA: sqrt_sparse_csr_ 5791 tags: pointwise 5792 5793- func: sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 5794 device_check: NoCheck # TensorIterator 5795 structured: True 5796 structured_inherits: TensorIteratorBase 5797 dispatch: 5798 CPU, CUDA: sqrt_out 5799 MPS: sqrt_out_mps 5800 SparseCPU, SparseCUDA: sqrt_sparse_out 5801 SparseCsrCPU, SparseCsrCUDA: sqrt_sparse_csr_out 5802 tags: pointwise 5803 5804- func: square(Tensor self) -> Tensor 5805 device_check: NoCheck # TensorIterator 5806 variants: function, method 5807 tags: pointwise 5808 5809- func: square_(Tensor(a!) self) -> Tensor(a!) 5810 device_check: NoCheck # TensorIterator 5811 variants: function, method 5812 tags: pointwise 5813 5814- func: square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 5815 tags: pointwise 5816 5817- func: std(Tensor self, bool unbiased=True) -> Tensor 5818 device_check: NoCheck # TensorIterator 5819 variants: function, method 5820 cpp_no_default_args: ["unbiased"] 5821 5822- func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor 5823 device_check: NoCheck # TensorIterator 5824 variants: function, method 5825 cpp_no_default_args: ["unbiased"] 5826 5827- func: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor 5828 device_check: NoCheck # TensorIterator 5829 variants: function, method 5830 dispatch: 5831 CPU, CUDA: std 5832 MPS: std_mps 5833 QuantizedCPU: std_quantized_cpu 5834 5835- func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) 5836 device_check: NoCheck # TensorIterator 5837 variants: function 5838 cpp_no_default_args: ["unbiased"] 5839 5840- func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) 5841 device_check: NoCheck # TensorIterator 5842 variants: function 5843 cpp_no_default_args: ["unbiased"] 5844 5845- func: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) 5846 device_check: NoCheck # TensorIterator 5847 variants: function 5848 dispatch: 5849 CPU, CUDA: std_mean 5850 MPS: std_mean_mps 5851 autogen: std_mean.correction_out 5852 5853- func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) 5854 device_check: NoCheck # TensorIterator 5855 variants: function 5856 cpp_no_default_args: ["unbiased"] 5857 5858- func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) 5859 device_check: NoCheck # TensorIterator 5860 variants: function 5861 5862- func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 5863 device_check: NoCheck # TensorIterator 5864 cpp_no_default_args: ["unbiased"] 5865 5866- func: std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) 
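# Illustrative usage sketch (not part of the schema): the correction overloads
# above generalize the older `unbiased` flag: correction=1 corresponds to
# unbiased=True (divide by N - 1, Bessel's correction) and correction=0 to
# unbiased=False (divide by N), e.g.:
#
#   x = torch.randn(5)
#   a = torch.std(x, unbiased=True)   # divide by N - 1
#   b = torch.std(x, correction=1)    # newer spelling of the same, if available
#   # a and b should match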
5867 device_check: NoCheck # TensorIterator 5868 dispatch: 5869 CPU, CUDA: std_out 5870 QuantizedCPU: std_out_quantized_cpu 5871 5872- func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor 5873 device_check: NoCheck # TensorIterator 5874 variants: function, method 5875 cpp_no_default_args: ["unbiased"] 5876 5877- func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 5878 device_check: NoCheck # TensorIterator 5879 cpp_no_default_args: ["unbiased"] 5880 5881- func: std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor 5882 device_check: NoCheck # TensorIterator 5883 variants: function, method 5884 5885- func: std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) 5886 device_check: NoCheck # TensorIterator 5887 variants: function 5888 5889- func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor 5890 device_check: NoCheck # TensorIterator 5891 variants: function, method 5892 dispatch: 5893 CPU, CUDA: prod 5894 MPS: prod_mps 5895 autogen: prod.out 5896 tags: core 5897 5898- func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor 5899 structured_delegate: prod.int_out 5900 device_check: NoCheck # TensorIterator 5901 variants: function, method 5902 tags: core 5903 5904- func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 5905 structured: True 5906 device_check: NoCheck # TensorIterator 5907 dispatch: 5908 CPU, CUDA: prod_out 5909 MPS: prod_out_mps 5910 5911- func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor 5912 device_check: NoCheck # TensorIterator 5913 variants: function, method 5914 5915- func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) 5916 device_check: NoCheck # TensorIterator 5917 5918- func: t(Tensor(a) self) -> Tensor(a) 5919 device_check: NoCheck 5920 device_guard: False 5921 variants: function, method 5922 dispatch: 5923 CompositeExplicitAutograd: t 5924 5925- func: t_(Tensor(a!) self) -> Tensor(a!) 5926 device_check: NoCheck 5927 device_guard: False 5928 variants: method 5929 tags: inplace_view 5930 dispatch: 5931 CompositeExplicitAutograd: t_ 5932 5933- func: tan(Tensor self) -> Tensor 5934 device_check: NoCheck # TensorIterator 5935 structured_delegate: tan.out 5936 variants: function, method 5937 dispatch: 5938 SparseCPU, SparseCUDA: tan_sparse 5939 SparseCsrCPU, SparseCsrCUDA: tan_sparse_csr 5940 tags: [core, pointwise] 5941 5942- func: tan_(Tensor(a!) self) -> Tensor(a!) 5943 device_check: NoCheck # TensorIterator 5944 structured_delegate: tan.out 5945 variants: function, method 5946 dispatch: 5947 SparseCPU, SparseCUDA: tan_sparse_ 5948 SparseCsrCPU, SparseCsrCUDA: tan_sparse_csr_ 5949 tags: pointwise 5950 5951- func: tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
5952 device_check: NoCheck # TensorIterator 5953 structured: True 5954 structured_inherits: TensorIteratorBase 5955 dispatch: 5956 CPU, CUDA: tan_out 5957 MPS: tan_out_mps 5958 SparseCPU, SparseCUDA: tan_sparse_out 5959 SparseCsrCPU, SparseCsrCUDA: tan_sparse_csr_out 5960 tags: pointwise 5961 5962- func: tanh(Tensor self) -> Tensor 5963 device_check: NoCheck # TensorIterator 5964 structured_delegate: tanh.out 5965 variants: function, method 5966 dispatch: 5967 QuantizedCPU: tanh_quantized_cpu 5968 MkldnnCPU: mkldnn_tanh 5969 SparseCPU, SparseCUDA: tanh_sparse 5970 SparseCsrCPU, SparseCsrCUDA: tanh_sparse_csr 5971 NestedTensorCPU, NestedTensorCUDA: NestedTensor_tanh 5972 tags: [core, pointwise] 5973 5974- func: tanh_(Tensor(a!) self) -> Tensor(a!) 5975 device_check: NoCheck # TensorIterator 5976 structured_delegate: tanh.out 5977 variants: function, method 5978 dispatch: 5979 MkldnnCPU: mkldnn_tanh_ 5980 SparseCPU, SparseCUDA: tanh_sparse_ 5981 SparseCsrCPU, SparseCsrCUDA: tanh_sparse_csr_ 5982 NestedTensorCPU, NestedTensorCUDA: NestedTensor_tanh_ 5983 tags: pointwise 5984 5985- func: tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 5986 device_check: NoCheck # TensorIterator 5987 structured: True 5988 structured_inherits: TensorIteratorBase 5989 dispatch: 5990 CPU, CUDA: tanh_out 5991 MPS: tanh_out_mps 5992 SparseCPU, SparseCUDA: tanh_sparse_out 5993 SparseCsrCPU, SparseCsrCUDA: tanh_sparse_csr_out 5994 tags: pointwise 5995 5996- func: tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor 5997 variants: function 5998 5999- func: tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!) 6000 variants: function 6001 6002# TODO: namespace threshold in 'nn' 6003- func: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor 6004 device_check: NoCheck # TensorIterator 6005 variants: function 6006 structured_delegate: threshold.out 6007 dispatch: 6008 QuantizedCPU: threshold_quantized_cpu 6009 6010- func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) 6011 device_check: NoCheck # TensorIterator 6012 variants: function 6013 structured_delegate: threshold.out 6014 6015- func: threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) 6016 device_check: NoCheck # TensorIterator 6017 structured: True 6018 structured_inherits: TensorIteratorBase 6019 dispatch: 6020 CPU, CUDA: threshold_out 6021 MPS: threshold_out_mps 6022 6023- func: threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) 
6024 structured: True 6025 structured_inherits: TensorIteratorBase 6026 dispatch: 6027 CPU, CUDA: threshold_backward_out 6028 MPS: threshold_backward_out_mps 6029 SparseCPU, SparseCUDA: threshold_backward_sparse_out 6030 SparseCsrCPU, SparseCsrCUDA: threshold_backward_sparse_compressed_out 6031 6032- func: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor 6033 variants: function 6034 structured_delegate: threshold_backward.grad_input 6035 dispatch: 6036 MkldnnCPU: mkldnn_relu_backward 6037 SparseCPU, SparseCUDA: threshold_backward_sparse 6038 SparseCsrCPU, SparseCsrCUDA: threshold_backward_sparse_compressed 6039 NestedTensorCPU, NestedTensorCUDA: threshold_backwards_nested 6040 tags: pointwise 6041 6042- func: tile(Tensor self, SymInt[] dims) -> Tensor 6043 variants: function, method 6044 dispatch: 6045 CompositeImplicitAutograd: tile_symint 6046 6047- func: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) 6048 variants: function, method 6049 device_check: NoCheck 6050 device_guard: False 6051 dispatch: 6052 CompositeExplicitAutograd: transpose 6053 NestedTensorCPU, NestedTensorCUDA: transpose_nested 6054 6055- func: transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a) 6056 variants: function, method 6057 device_check: NoCheck 6058 device_guard: False 6059 6060- func: _mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor 6061 device_check: NoCheck 6062 device_guard: False 6063 dispatch: 6064 MkldnnCPU: mkldnn_transpose 6065 6066- func: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) 6067 variants: method 6068 device_check: NoCheck 6069 device_guard: False 6070 tags: inplace_view 6071 dispatch: 6072 CompositeExplicitAutograd: transpose_ 6073 6074- func: _mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) 6075 device_check: NoCheck 6076 device_guard: False 6077 dispatch: 6078 MkldnnCPU: mkldnn_transpose_ 6079 autogen: _mkldnn_transpose.out 6080 6081- func: one_hot(Tensor self, int num_classes=-1) -> Tensor 6082 python_module: nn 6083 variants: function 6084 tags: dynamic_output_shape 6085 6086- func: flip(Tensor self, int[] dims) -> Tensor 6087 variants: function, method 6088 dispatch: 6089 CPU, QuantizedCPU, CUDA, QuantizedCUDA: flip 6090 MPS: flip_mps 6091 autogen: flip.out 6092 tags: core 6093 6094- func: fliplr(Tensor self) -> Tensor 6095 variants: function, method 6096 6097- func: flipud(Tensor self) -> Tensor 6098 variants: function, method 6099 6100- func: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor 6101 variants: function, method 6102 dispatch: 6103 CPU, MPS: roll 6104 CUDA: roll_cuda 6105 autogen: roll.out 6106 6107# default int[] value [0,1] should not add space after comma, since codegen parser uses ', ' to split args 6108 6109- func: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor 6110 variants: function, method 6111 dispatch: 6112 CompositeExplicitAutograd: rot90 6113 autogen: rot90.out 6114 6115- func: trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor 6116 6117- func: trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor 6118 6119- func: trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor 6120 6121- func: trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor 6122 6123# Fused implementation detail for transformers. Adds in-projection bias to QKV and divides Q by sqrt(D/num_heads). 
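# As a rough, illustrative sketch only (this is not the actual kernel; `embed_dim`
# is an assumed name here for qkv.shape[-1] // 3):
#
#   q, k, v = (qkv + qkv_bias).chunk(3, dim=-1)
#   q = q / math.sqrt(embed_dim / num_heads)
#
# The real implementation additionally splits each of q, k, v into per-head chunks.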
6124- func: _transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor) 6125 dispatch: 6126 CPU, NestedTensorCPU: transform_bias_rescale_qkv_cpu 6127 CUDA, NestedTensorCUDA: transform_bias_rescale_qkv_cuda 6128 autogen: _transform_bias_rescale_qkv.out 6129 6130- func: _nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor 6131 dispatch: 6132 CPU, CUDA: NestedTensor_nested_tensor_from_mask 6133 autogen: _nested_tensor_from_mask.out 6134 6135- func: _nested_tensor_from_mask_left_aligned(Tensor t, Tensor mask) -> bool 6136 dispatch: 6137 CPU, CUDA: NestedTensor_nested_tensor_from_mask_left_aligned 6138 6139- func: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor 6140 device_check: NoCheck # cpu_nested_shape_example will always be on CPU 6141 dispatch: 6142 CPU: nested_from_padded_generic 6143 CUDA: nested_from_padded_cuda 6144 autogen: _nested_from_padded.out 6145 6146# These private functions are temporary. They will be updated/deleted when nested tensors switch to using SymInts for their metadata representation 6147- func: _nested_tensor_size(Tensor self) -> Tensor 6148 variants: method 6149 dispatch: 6150 NestedTensorCPU, NestedTensorCUDA: _nested_tensor_size 6151 autogen: _nested_tensor_size.out 6152 6153- func: _nested_tensor_strides(Tensor self) -> Tensor 6154 variants: method 6155 dispatch: 6156 NestedTensorCPU, NestedTensorCUDA: _nested_tensor_strides 6157 autogen: _nested_tensor_strides.out 6158 6159- func: _nested_tensor_storage_offsets(Tensor self) -> Tensor 6160 variants: method 6161 dispatch: 6162 NestedTensorCPU, NestedTensorCUDA, NestedTensorMeta: _nested_tensor_storage_offsets 6163 autogen: _nested_tensor_storage_offsets.out 6164 6165# _nested_from_padded is not usable from Python, so 6166# _nested_from_padded_and_nested_example is available for testing. 6167- func: _nested_from_padded_and_nested_example(Tensor padded, Tensor nt_example) -> Tensor 6168 dispatch: 6169 NestedTensorCPU, NestedTensorCUDA: NestedTensor_from_padded_and_nested_example 6170 autogen: _nested_from_padded_and_nested_example.out 6171 6172# The input arguments' types to this functions are temporary. When nested tensors switch to using SymInts for their metadata representation 6173# this will need to be updated 6174- func: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) 6175 variants: function 6176 device_check: NoCheck 6177 dispatch: 6178 CPU, CUDA: _nested_view_from_buffer 6179 6180- func: _nested_view_from_buffer_copy(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor 6181 variants: function 6182 device_check: NoCheck 6183 tags: view_copy 6184 dispatch: 6185 CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy 6186 autogen: _nested_view_from_buffer_copy.out 6187 6188- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1) -> Tensor(a) 6189 variants: function 6190 device_check: NoCheck 6191 dispatch: {} 6192 6193- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? 
lengths=None, int ragged_idx=1) -> Tensor 6194 variants: function 6195 device_check: NoCheck 6196 tags: view_copy 6197 dispatch: 6198 CompositeExplicitAutogradNonFunctional: _nested_view_from_jagged_copy 6199 autogen: _nested_view_from_jagged_copy.out 6200 6201- func: _nested_get_values(Tensor(a) self) -> Tensor(a) 6202 variants: function 6203 device_check: NoCheck 6204 dispatch: {} 6205 6206- func: _nested_get_values_copy(Tensor self) -> Tensor 6207 variants: function 6208 device_check: NoCheck 6209 tags: view_copy 6210 dispatch: 6211 CompositeExplicitAutogradNonFunctional: _nested_get_values_copy 6212 autogen: _nested_get_values_copy.out 6213 6214- func: _nested_get_offsets(Tensor self) -> Tensor 6215 variants: function 6216 device_check: NoCheck 6217 dispatch: {} 6218 6219# returns undefined Tensor if no lengths present 6220- func: _nested_get_lengths(Tensor self) -> Tensor 6221 variants: function 6222 device_check: NoCheck 6223 dispatch: {} 6224 6225- func: _nested_get_ragged_idx(Tensor self) -> int 6226 variants: function 6227 device_check: NoCheck 6228 dispatch: {} 6229 6230- func: _nested_get_jagged_dummy(Tensor any) -> Tensor 6231 category_override: dummy 6232 dispatch: {} 6233 6234- func: _nested_compute_contiguous_strides_offsets(Tensor nested_size) -> (Tensor, Tensor) 6235 variants: function 6236 device_check: NoCheck 6237 dispatch: 6238 CPU, CUDA: _nested_compute_contiguous_strides_offsets 6239 6240- func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor 6241 dispatch: 6242 # calls unsqueeze 6243 CompositeExplicitAutogradNonFunctional: _trilinear 6244 autogen: _trilinear.out 6245 6246- func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor 6247 6248- func: trunc(Tensor self) -> Tensor 6249 structured_delegate: trunc.out 6250 device_check: NoCheck # TensorIterator 6251 variants: function, method 6252 dispatch: 6253 SparseCPU, SparseCUDA: trunc_sparse 6254 SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr 6255 tags: [core, pointwise] 6256 6257- func: trunc_(Tensor(a!) self) -> Tensor(a!) 6258 structured_delegate: trunc.out 6259 device_check: NoCheck # TensorIterator 6260 variants: function, method 6261 dispatch: 6262 SparseCPU, SparseCUDA: trunc_sparse_ 6263 SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr_ 6264 tags: pointwise 6265 6266- func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 6267 structured: True 6268 structured_inherits: TensorIteratorBase 6269 device_check: NoCheck # TensorIterator 6270 dispatch: 6271 CPU, CUDA: trunc_out 6272 MPS: trunc_out_mps 6273 SparseCPU, SparseCUDA: trunc_sparse_out 6274 SparseCsrCPU, SparseCsrCUDA: trunc_sparse_csr_out 6275 tags: pointwise 6276# Alias for trunc 6277 6278- func: fix(Tensor self) -> Tensor 6279 variants: function, method 6280 6281- func: fix_(Tensor(a!) self) -> Tensor(a!) 6282 variants: function, method 6283 6284- func: fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
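# As a quick illustration of the alias (fix behaves exactly like trunc, i.e. it
# rounds toward zero):
#
#   torch.trunc(torch.tensor([2.7, -1.3]))  # tensor([ 2., -1.])
#   torch.fix(torch.tensor([2.7, -1.3]))    # tensor([ 2., -1.])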

- func: type_as(Tensor self, Tensor other) -> Tensor
  variants: method

- func: _has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool
  variants: function

- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
  variants: function
  dispatch:
    CPU: _unique_cpu
    CUDA: _unique_cuda
  autogen: _unique.out

- func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
  variants: function
  dispatch:
    CPU: unique_dim_cpu
    CUDA: unique_dim_cuda
  tags: dynamic_output_shape
  autogen: unique_dim.out

- func: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)
  variants: function
  dispatch:
    CPU: unique_consecutive_cpu
    CUDA: unique_consecutive_cuda
    MPS: unique_consecutive_mps
  tags: dynamic_output_shape
  autogen: unique_consecutive.out

- func: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
  variants: function
  dispatch:
    CPU: unique_dim_consecutive_cpu
    CUDA: unique_dim_consecutive_cuda
    MPS: unique_dim_consecutive_mps
  tags: dynamic_output_shape
  autogen: unique_dim_consecutive.out

# _unique and _unique_dim are fragile; modifying them can easily cause internal breakage.
# The operator below is a temporary hack for adding return_counts support.
# Please don't rely on these two operators; they will be removed soon.

- func: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
  variants: function
  dispatch:
    CPU: _unique2_cpu
    CUDA: _unique2_cuda
    MPS: _unique2_mps
  tags: dynamic_output_shape
  autogen: _unique2.out

- func: _unsafe_view(Tensor self, SymInt[] size) -> Tensor
  dispatch:
    CompositeExplicitAutograd: _unsafe_view
  autogen: _unsafe_view.out

- func: unsqueeze(Tensor(a) self, int dim) -> Tensor(a)
  variants: function, method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeExplicitAutograd: unsqueeze
    SparseCPU, SparseCUDA: unsqueeze_sparse
    QuantizedCPU, QuantizedCUDA: unsqueeze_quantized
    NestedTensorCPU, NestedTensorCUDA: unsqueeze_nested
  tags: core

- func: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)
  variants: method
  device_check: NoCheck
  device_guard: False
  tags: inplace_view
  dispatch:
    CompositeExplicitAutograd: unsqueeze_

- func: vander(Tensor x, int? N=None, bool increasing=False) -> Tensor

- func: var(Tensor self, bool unbiased=True) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  cpp_no_default_args: ["unbiased"]

- func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  tags: core
  cpp_no_default_args: ["unbiased"]

- func: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function, method
  dispatch:
    CPU, CUDA: var
    MPS: var_mps
  tags: core

- func: var.out(Tensor self, int[1]?
dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 6384 device_check: NoCheck # TensorIterator 6385 cpp_no_default_args: ["unbiased"] 6386 6387- func: var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) 6388 device_check: NoCheck # TensorIterator 6389 dispatch: 6390 CPU, CUDA: var_out 6391 6392- func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor 6393 device_check: NoCheck # TensorIterator 6394 variants: function, method 6395 cpp_no_default_args: ["unbiased"] 6396 6397- func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 6398 device_check: NoCheck # TensorIterator 6399 cpp_no_default_args: ["unbiased"] 6400 6401- func: var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor 6402 device_check: NoCheck # TensorIterator 6403 variants: function, method 6404 6405- func: var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) 6406 device_check: NoCheck # TensorIterator 6407 variants: function 6408 6409- func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) 6410 device_check: NoCheck # TensorIterator 6411 variants: function 6412 cpp_no_default_args: ["unbiased"] 6413 6414- func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) 6415 device_check: NoCheck # TensorIterator 6416 variants: function 6417 cpp_no_default_args: ["unbiased"] 6418 6419- func: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) 6420 device_check: NoCheck # TensorIterator 6421 variants: function 6422 dispatch: 6423 CPU, CUDA: var_mean 6424 MPS: var_mean_mps 6425 autogen: var_mean.correction_out 6426 6427- func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) 6428 device_check: NoCheck # TensorIterator 6429 variants: function 6430 cpp_no_default_args: ["unbiased"] 6431 6432- func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) 6433 device_check: NoCheck # TensorIterator 6434 variants: function 6435 6436- func: view_as(Tensor(a) self, Tensor other) -> Tensor(a) 6437 variants: method 6438 device_check: NoCheck 6439 device_guard: False 6440 6441- func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor 6442 device_check: NoCheck # TensorIterator 6443 variants: function, method 6444 dispatch: 6445 CPU, CUDA, MPS: where 6446 tags: [core, pointwise] 6447 6448- func: where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
6449 device_check: NoCheck # TensorIterator 6450 dispatch: 6451 CPU, CUDA, MPS: where_self_out 6452 6453- func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor 6454 variants: function 6455 6456- func: where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor 6457 variants: function, method 6458 6459- func: where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor 6460 variants: function 6461 6462- func: where(Tensor condition) -> Tensor[] 6463 device_check: NoCheck # TensorIterator 6464 variants: function 6465 6466- func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor 6467 variants: function 6468 6469# VariableType::_weight_norm does not want to be given a gap in the autograd graph, 6470# so we don't define "dispatch" variants for it. 6471- func: _weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor 6472 variants: function 6473 6474- func: _weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor) 6475 variants: function 6476 dispatch: 6477 CPU: weight_norm_cpu 6478 CUDA: weight_norm_cuda 6479 MPS: weight_norm_mps 6480 autogen: _weight_norm_interface.out 6481 6482- func: _weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) 6483 variants: function 6484 dispatch: 6485 CPU: weight_norm_backward_cpu 6486 CUDA: weight_norm_backward_cuda 6487 MPS: weight_norm_backward_mps 6488 autogen: _weight_norm_interface_backward.out 6489 6490- func: _weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) 6491 variants: function 6492 6493- func: zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 6494 device_check: NoCheck 6495 device_guard: False 6496 dispatch: 6497 CompositeExplicitAutograd: zeros 6498 autogen: zeros.names_out 6499 6500- func: _efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 6501 dispatch: 6502 CPU: _efficientzerotensor 6503 CUDA: _efficientzerotensor_cuda 6504 MPS: _efficientzerotensor_mps 6505 Meta: _efficientzerotensor_meta_symint 6506 autogen: _efficientzerotensor.out 6507 6508- func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 6509 dispatch: 6510 CompositeExplicitAutograd: zeros_symint 6511 6512- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) 6513 dispatch: 6514 CompositeExplicitAutograd: zeros_out 6515 SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out 6516 6517- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor 6518 dispatch: 6519 # NB: Although this composite mutates on the inside, it is 6520 # non-differentiable so NonFunctional doesn't apply 6521 CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: zeros_like 6522 autogen: zeros_like.out 6523 6524- func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor 6525 variants: function 6526 dispatch: 6527 CPU: _standard_gamma_grad_cpu 6528 CUDA: _standard_gamma_grad_cuda 6529 autogen: _standard_gamma_grad.out 6530 6531- func: _standard_gamma(Tensor self, Generator? 
generator=None) -> Tensor 6532 variants: function 6533 dispatch: 6534 CPU: _s_gamma_cpu 6535 CUDA: _s_gamma_cuda 6536 tags: nondeterministic_seeded 6537 autogen: _standard_gamma.out 6538 6539- func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor 6540 dispatch: 6541 CPU: _dirichlet_grad_cpu 6542 CUDA: _dirichlet_grad_cuda 6543 autogen: _dirichlet_grad.out 6544 6545- func: _sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor 6546 tags: nondeterministic_seeded 6547 variants: function 6548 dispatch: 6549 CPU: _s_dirichlet_cpu 6550 CUDA: _s_dirichlet_cuda 6551 autogen: _sample_dirichlet.out 6552 6553- func: poisson(Tensor self, Generator? generator=None) -> Tensor 6554 device_check: NoCheck # TensorIterator 6555 dispatch: 6556 CPU: _s_poisson_cpu 6557 CUDA: _s_poisson_cuda 6558 tags: nondeterministic_seeded 6559 autogen: poisson.out 6560 6561- func: binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor 6562 device_check: NoCheck # TensorIterator 6563 dispatch: 6564 CPU: _s_binomial_cpu 6565 CUDA: _s_binomial_cuda 6566 tags: nondeterministic_seeded 6567 autogen: binomial.out 6568 6569# When more variants get ported to native, this dispatch will get more 6570# complicated 6571 6572- func: native_norm(Tensor self, Scalar p=2) -> Tensor 6573 dispatch: 6574 SparseCPU, SparseCUDA: norm_sparse 6575 autogen: native_norm.out 6576 6577- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor 6578 dispatch: 6579 SparseCPU, SparseCUDA: norm_sparse 6580 autogen: native_norm.ScalarOpt_dim_dtype_out 6581 6582- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) 6583 dispatch: 6584 CPU: _batch_norm_with_update_cpu 6585 CUDA: _batch_norm_with_update_cuda 6586 MPS: _batch_norm_with_update_mps 6587 MkldnnCPU: _batch_norm_with_update_mkldnn 6588 autogen: _batch_norm_with_update_functional 6589 6590- func: _batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) 6591 dispatch: 6592 CPU: _batch_norm_with_update_cpu_out 6593 CUDA: _batch_norm_with_update_cuda_out 6594 MPS: _batch_norm_with_update_mps_out 6595 6596- func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) 6597 dispatch: 6598 CompositeExplicitAutograd: _batch_norm_no_update 6599 autogen: _batch_norm_no_update.out 6600 6601- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor) 6602 dispatch: 6603 CPU: _new_batch_norm_backward_cpu 6604 CUDA: _new_batch_norm_backward_cuda 6605 MPS: _new_batch_norm_backward_mps 6606 MkldnnCPU: _new_batch_norm_backward_mkldnn 6607 6608# TODO: reduce signatures down to one when optional args is available 6609- func: _sparse_sum(Tensor self) -> Tensor 6610 6611- func: _sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor 6612 6613- func: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor 6614 dispatch: 6615 CompositeExplicitAutograd: _sparse_sum 6616 autogen: _sparse_sum.dim_out 6617 6618- func: _sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor 6619 6620- func: _sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor 6621 dispatch: 6622 SparseCPU: _sparse_sum_backward_cpu 6623 SparseCUDA: _sparse_sum_backward_cuda 6624 autogen: _sparse_sum_backward.out 6625 6626- func: _sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor 6627 dispatch: 6628 SparseCsrCPU: _sparse_csr_sum_cpu 6629 SparseCsrCUDA: _sparse_csr_sum_cuda 6630 autogen: _sparse_csr_sum.dim_dtype_out 6631 6632- func: _sparse_csr_prod.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor 6633 dispatch: 6634 SparseCsrCPU: _sparse_csr_prod_cpu 6635 SparseCsrCUDA: _sparse_csr_prod_cuda 6636 autogen: _sparse_csr_prod.dim_dtype_out 6637 6638- func: _sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor 6639 python_module: sparse 6640 variants: function 6641 6642- func: _sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor 6643 python_module: sparse 6644 variants: function 6645 6646- func: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor 6647 python_module: sparse 6648 dispatch: 6649 SparseCPU: softmax_sparse_cpu 6650 SparseCUDA: softmax_sparse_cuda 6651 autogen: _sparse_softmax.out 6652 6653- func: _sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor 6654 dispatch: 6655 SparseCPU: softmax_backward_sparse_cpu 6656 SparseCUDA: softmax_backward_sparse_cuda 6657 autogen: _sparse_softmax_backward_data.out 6658 6659- func: _sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor 6660 python_module: sparse 6661 variants: function 6662 6663- func: _sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor 6664 python_module: sparse 6665 variants: function 6666 6667- func: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor 6668 python_module: sparse 6669 dispatch: 6670 SparseCPU: log_softmax_sparse_cpu 6671 SparseCUDA: log_softmax_sparse_cuda 6672 autogen: _sparse_log_softmax.out 6673 6674- func: _sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor 6675 dispatch: 6676 SparseCPU: log_softmax_backward_sparse_cpu 6677 SparseCUDA: log_softmax_backward_sparse_cuda 6678 autogen: _sparse_log_softmax_backward_data.out 6679 6680- func: _spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor 6681 python_module: sparse 6682 dispatch: 6683 CPU: spdiags 6684 autogen: _spdiags.out 6685 6686- func: norm.ScalarOpt_dtype(Tensor self, Scalar? 
p, *, ScalarType dtype) -> Tensor 6687 device_check: NoCheck # TensorIterator 6688 variants: function, method 6689 dispatch: 6690 CompositeExplicitAutograd: norm 6691 autogen: norm.ScalarOpt_dtype_out 6692 6693- func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor 6694 device_check: NoCheck # TensorIterator 6695 variants: function, method 6696 dispatch: 6697 CompositeExplicitAutograd: norm 6698 autogen: norm.Scalar_out 6699 6700- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor 6701 structured_delegate: norm.dtype_out 6702 device_check: NoCheck # TensorIterator 6703 variants: function, method 6704 dispatch: 6705 SparseCPU, SparseCUDA: sparse_dtype_norm 6706 6707- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor 6708 structured_delegate: norm.out 6709 device_check: NoCheck # TensorIterator 6710 variants: function, method 6711 dispatch: 6712 SparseCPU, SparseCUDA: sparse_norm 6713 6714- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) 6715 structured: True 6716 device_check: NoCheck # TensorIterator 6717 dispatch: 6718 CPU, CUDA: norm_dtype_out 6719 MPS: norm_dtype_out_mps 6720 6721- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 6722 structured: True 6723 device_check: NoCheck # TensorIterator 6724 dispatch: 6725 CPU, CUDA: norm_out 6726 MPS: norm_out_mps 6727 6728# These four redispatch in their implementation, so OK to be CompositeImplicitAutograd 6729- func: norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor 6730 device_check: NoCheck # TensorIterator 6731 variants: function, method 6732 6733- func: norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor 6734 device_check: NoCheck # TensorIterator 6735 variants: function, method 6736 6737- func: norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) 6738 device_check: NoCheck # TensorIterator 6739 6740- func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 6741 device_check: NoCheck # TensorIterator 6742 6743- func: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) 6744 variants: method, function 6745 dispatch: 6746 CompositeExplicitAutograd: frexp 6747 tags: pointwise 6748 6749- func: frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent) 6750 dispatch: 6751 CPU, CUDA: frexp_out 6752 tags: pointwise 6753 6754# Deprecated (v.1.12) 6755- func: frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor 6756 variants: function 6757 6758# Deprecated (v.1.12) 6759- func: frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 6760 variants: function 6761 6762# Deprecated (v.1.12) 6763- func: nuclear_norm(Tensor self, bool keepdim=False) -> Tensor 6764 variants: function 6765 6766# Deprecated (v.1.12) 6767- func: nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
6768 variants: function 6769 6770# Deprecated (v.1.12) 6771- func: nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor 6772 variants: function 6773 6774# Deprecated (v.1.12) 6775- func: nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 6776 variants: function 6777 6778- func: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor 6779 variants: function, method 6780 dispatch: 6781 CompositeExplicitAutograd: clone 6782 SparseCPU, SparseCUDA: clone_sparse 6783 SparseCsrCPU, SparseCsrCUDA: clone_sparse_compressed 6784 MkldnnCPU: mkldnn_clone 6785 QuantizedCPU, QuantizedCUDA: quantized_clone 6786 NestedTensorCPU, NestedTensorCUDA: clone_nested 6787 autogen: clone.out 6788 tags: [core, pointwise] 6789 6790- func: positive(Tensor(a) self) -> Tensor(a) 6791 variants: function, method 6792 tags: pointwise 6793 6794- func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) 6795 use_const_ref_for_mutable_tensors: True 6796 variants: function, method 6797 dispatch: 6798 CompositeExplicitAutograd: resize_as_ 6799 autogen: resize_as, resize_as.out 6800 tags: inplace_view 6801 6802- func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!) 6803 use_const_ref_for_mutable_tensors: True 6804 variants: function, method 6805 dispatch: 6806 SparseCPU, SparseCUDA: resize_as_sparse_ 6807 SparseCsrCPU, SparseCsrCUDA: resize_as_sparse_compressed_ 6808 autogen: resize_as_sparse, resize_as_sparse.out 6809 6810- func: zero_(Tensor(a!) self) -> Tensor(a!) 6811 device_check: NoCheck # TensorIterator 6812 variants: method, function 6813 dispatch: 6814 CPU, CUDA: zero_ 6815 MPS: zero_mps_ 6816 Meta: zero_meta_ 6817 SparseCPU, SparseCUDA, SparseMeta: zero_sparse_ 6818 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: zero_sparse_csr_ 6819 MkldnnCPU: mkldnn_zero_ 6820 NestedTensorCPU, NestedTensorCUDA: zero_nested_ 6821 autogen: zero, zero.out 6822 6823- func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 6824 device_check: NoCheck # TensorIterator 6825 structured: True 6826 structured_inherits: TensorIteratorBase 6827 dispatch: 6828 CPU, CUDA: sub_out 6829 MPS: sub_out_mps 6830 SparseCPU, SparseCUDA: sub_out_sparse 6831 tags: pointwise 6832 6833- func: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor 6834 device_check: NoCheck # TensorIterator 6835 variants: function, method 6836 structured_delegate: sub.out 6837 dispatch: 6838 SparseCPU, SparseCUDA: sub_sparse 6839 ZeroTensor: sub_zerotensor 6840 NestedTensorCPU, NestedTensorCUDA: NestedTensor_sub_Tensor 6841 tags: [core, pointwise] 6842 6843- func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) 6844 device_check: NoCheck # TensorIterator 6845 variants: method 6846 structured_delegate: sub.out 6847 dispatch: 6848 SparseCPU, SparseCUDA: sub_sparse_ 6849 tags: pointwise 6850# For C++ only, until we have conversion from C++ numbers to Tensor 6851 6852- func: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor 6853 device_check: NoCheck # TensorIterator 6854 variants: function, method 6855 dispatch: 6856 CompositeExplicitAutograd: sub 6857 tags: [core, pointwise] 6858 6859- func: sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) 
6860 device_check: NoCheck # TensorIterator 6861 variants: method 6862 dispatch: 6863 CompositeExplicitAutograd: sub_ 6864 autogen: sub.Scalar_out 6865 tags: pointwise 6866# subtract, alias for sub 6867 6868- func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 6869 6870- func: subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor 6871 variants: function, method 6872 6873- func: subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) 6874 variants: method 6875 6876# For C++ only, until we have conversion from C++ numbers to Tensor 6877- func: subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor 6878 variants: function, method 6879 6880- func: subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) 6881 variants: method 6882 6883- func: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor 6884 device_check: NoCheck # TensorIterator 6885 variants: function 6886 dispatch: 6887 CPU, CUDA: rsub 6888 autogen: rsub.Tensor_out 6889 6890- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) 6891 structured: True 6892 structured_inherits: TensorIteratorBase 6893 device_check: NoCheck # TensorIterator 6894 dispatch: 6895 CPU, CUDA: heaviside_out 6896 tags: pointwise 6897 6898- func: heaviside(Tensor self, Tensor values) -> Tensor 6899 device_check: NoCheck # TensorIterator 6900 variants: function, method 6901 structured_delegate: heaviside.out 6902 tags: pointwise 6903 6904- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!) 6905 device_check: NoCheck # TensorIterator 6906 variants: method 6907 structured_delegate: heaviside.out 6908 6909# For C++ only, until we have conversion from C++ numbers to Tensor 6910- func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor 6911 device_check: NoCheck # TensorIterator 6912 variants: function 6913 dispatch: 6914 CompositeExplicitAutograd: rsub 6915 autogen: rsub.Scalar_out 6916 6917# Functionally the same as addmm, but we give it a different derivative formula 6918# that doesn't propagate gradients to non-present entries on sparse. 6919 tags: pointwise 6920- func: _sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor 6921 python_module: sparse 6922 dispatch: 6923 CompositeExplicitAutograd: _sparse_addmm 6924 autogen: _sparse_addmm.out 6925 6926- func: sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 
  python_module: sparse
  dispatch:
    SparseCsrCUDA: sparse_sampled_addmm_out_sparse_csr_cuda
    SparseCsrCPU: sparse_sampled_addmm_out_sparse_csr_cpu

- func: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
  python_module: sparse
  dispatch:
    SparseCsrCUDA: sparse_sampled_addmm_sparse_csr_cuda
    SparseCsrCPU: sparse_sampled_addmm_sparse_csr_cpu

- func: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)
  python_module: sparse
  dispatch:
    SparseCsrCPU: _sparse_mm_reduce_impl_sparse_csr_cpu

- func: _sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor)
  python_module: sparse
  dispatch:
    SparseCsrCPU: _sparse_mm_reduce_impl_backward_sparse_csr_cpu

- func: addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:
    CPU: addmm_out_cpu
    CUDA: addmm_out_cuda
    MPS: addmm_out_mps
    SparseCPU: addmm_out_sparse_dense_cpu
    SparseCUDA: addmm_out_sparse_dense_cuda
    SparseCsrCPU: addmm_out_sparse_compressed_cpu
    SparseCsrCUDA: addmm_out_sparse_compressed_cuda

- func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
  structured_delegate: addmm.out
  variants: function, method
  dispatch:
    SparseCPU: addmm_sparse_dense_cpu
    SparseCUDA: addmm_sparse_dense_cuda
    SparseCsrCPU, SparseCsrCUDA: addmm_sparse_compressed_dense
  tags: core

- func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
  structured_delegate: addmm.out
  variants: method
  dispatch:
    # Warning! For whatever reason, the inplace sparse addmm is NON-broadcasting
    SparseCPU: s_addmm_sparse_dense_cpu_
    SparseCUDA: s_addmm_sparse_dense_cuda_

- func: _addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:
    CPU: addmm_activation_out_cpu
    CUDA: addmm_activation_out_cuda

- func: _addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor
  structured_delegate: _addmm_activation.out
  variants: function, method

- func: _scaled_mm(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False) -> (Tensor, Tensor)
  variants: function
  dispatch:
    CUDA: _scaled_mm_cuda

- func: _scaled_mm.out(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False, Tensor(a!) out, Tensor(b!) out_amax) -> (Tensor(a!), Tensor(b!))
  variants: function
  dispatch:
    CUDA: _scaled_mm_out_cuda

# NOTE [ Sparse: autograd and API ]
#
#
# Sparse Tensor Constructors
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The API entry points to sparse tensor construction should be
# `sparse_coo_tensor` and `_sparse_coo_tensor_unsafe`.
# Depending on whether the
# indices and values tensors are given, they eventually dispatch to either
# `sparse_coo_tensor_with_dims` or `sparse_coo_tensor_with_dims_and_tensors`.
#
# The autograd support for the ctor is implemented on `sparse_coo_tensor_with_dims_and_tensors`.
#
# The API methods `sparse_coo_tensor` and `_sparse_coo_tensor_unsafe`
# **must not** have specific type dispatches because otherwise codegen will
# consider them as abstract methods (see Note [Abstract ATen methods]), dispatch
# using **Tensor** type, and thus lose autograd tracking on the actual method
# they dispatch to, e.g., `sparse_coo_tensor_with_dims_and_tensors`.
#
#
# Sparse Methods API Design
# ~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Goals: 1. Flexible API for users to write custom sparse ops
#        2. ctor and member accessor with autograd support
#
# To achieve 1, we need to provide a set of *dangerous* APIs (dangerous in the
# sense that misusing them will break the sparse tensor invariant and may result in
# unexpected behavior, e.g., crash). These methods are all prefixed with
# underscore "_" to indicate that they should be used with care. We provide:
#
# + `_indices()`: returns the *raw* indices within the sparse tensor (not just
#                 sharing storage). Any inplace operation will change the
#                 actual indices, including t_, set_, as_strided_, resize_,
#                 etc.
# + `_values()`: returns the *raw* values within the sparse tensor. Similar
#                semantics as `_indices()`
# + `_nnz()`: returns the number of non-zero entries. This will always be
#             determined by the shapes of indices and values.
# + `_coalesced_(bool)`: inplace sets whether the tensor is coalesced, and
#                        returns itself.
#
# These methods are very useful in writing new operations, e.g., a custom
# autograd Function.
#
# We also provide other public *safe* APIs:
# + `indices()`: returns a **view** of the indices tensor if the sparse tensor
#                is **coalesced**.
# + `values()`: returns a **view** of the values tensor if the containing
#               sparse tensor is **coalesced**.
# + `sparse_dim()`: number of sparse dimensions
# + `dense_dim()`: number of dense dimensions
# + `is_coalesced()`: whether the sparse tensor is coalesced
#
# `_indices()` and `_values()` should return the raw indices and values dense
# tensors within a sparse tensor. They can be quite unsafe with inplace
# operations like `t_()`, and expose uncoalesced indices and values. The public
# recommended API is `indices()` and `values()`, both of which first check that
# the tensor is coalesced and return views on those tensors.
#
#
# Autograd Support
# ~~~~~~~~~~~~~~~~
#
# Autograd is supported on `values()` and the sparse tensor ctor with indices and
# values tensors. E.g., `torch.sparse_coo_tensor(i, v).values().sum()` is
# differentiable w.r.t. `v`.
#
# NB: The `values()` and `_values()` operators are special in that they are
# layout-aware, i.e., the output depends not just on the data it represents, but
# also on the input layout details (in this case, the `indices` tensor). See
# NOTE [ as_strided Backward and layout-aware/agnostic autograd ] in Functions.cpp
# for discussion on layout-aware vs layout-agnostic autograd.
Since PyTorch ops 7070# operate in the layout-agnostic mode, similar to `as_strided`, backward of 7071# these two operators need to consider them in a layout-agnostic way: 7072# + `values()`: 7073# Input is coalesced. 7074# We just pretend having `input.indices()` as an additional argument 7075# `input_indices`, then forward is similar to 7076# `input.to(kStrided).index_select(input_indices)` regardless of the layout. 7077# Note that `values()` normally is layout-aware even if we constrain 7078# ourselves on sparse inputs since it may include all zeros values entries 7079# as "present" entries. 7080# + `_values()`: 7081# Input may be uncoalesced. 7082# It is not straightforward to construct a layout-agnostic version because 7083# duplicate indices entries may exist and additional parameterization is 7084# needed to distribute the value into different values entries. Furthermore, 7085# this op is intended to provide ways to write custom sparse ops, rather 7086# than being used in autograd graph, so it is marked as *non-differentiable* 7087# in derivatives.yaml. 7088# 7089# Before reading the following, see NOTE [ Autograd Variable Views ] in 7090# variable.h for details on views that are tracked by autograd, and views that 7091# are not. 7092# 7093# Moreover, these methods return tensors that share storage with inputs, so we 7094# mark these methods as view ops to support autograd history tracking. 7095# The sparse tensor ctor output should technically be view of both input indices 7096# and values tensors, but currently we only support setting as view of a single 7097# Variable, so it is only view of the values tensor. 7098# TODO: clone indices in sparse tensor ctor. 7099# 7100# For other methods that return outputs that share storage with inputs, i.e., 7101# `indices()` and `_indices()`. We mark their outputs as non-differentiable, so 7102# the view relation is not tracked by autograd, but the version counter is still 7103# shared. In other words, their outputs are non-differentiable views of the 7104# sparse tensor. 7105# FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given 7106# the default would never make sense. 7107 7108- func: _sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7109 dispatch: 7110 CompositeExplicitAutograd: sparse_compressed_tensor_with_dims 7111 7112- func: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7113 dispatch: 7114 CompositeExplicitAutograd: sparse_compressed_tensor 7115 7116- func: sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7117- func: sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7118- func: sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=False) -> Tensor 7119- func: sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7120 7121- func: sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7122 dispatch: 7123 CompositeExplicitAutograd: sparse_compressed_tensor 7124- func: sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7125- func: sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7126- func: sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7127- func: sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7128 7129- func: _sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 7130 dispatch: 7131 CompositeImplicitAutograd: _sparse_compressed_tensor_unsafe_symint 7132 7133- func: _sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 7134- func: _sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 7135- func: _sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 7136- func: _sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 7137 7138- func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7139 dispatch: 7140 CompositeExplicitAutograd: sparse_coo_tensor 7141 autogen: sparse_coo_tensor.size_out 7142 7143- func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor 7144 7145- func: sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor 7146 7147- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? 
is_coalesced=None) -> Tensor 7148 dispatch: 7149 CompositeImplicitAutograd: _sparse_coo_tensor_unsafe_symint 7150 7151- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None) -> () 7152 7153- func: _validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout) -> () 7154- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> () 7155- func: _validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> () 7156- func: _validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> () 7157- func: _validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size) -> () 7158 7159- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor 7160 dispatch: 7161 SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse 7162 autogen: _sparse_coo_tensor_with_dims.out 7163 7164- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor 7165 dispatch: 7166 SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse_symint 7167 autogen: _sparse_coo_tensor_with_dims_and_tensors.out 7168 7169- func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) 7170 use_const_ref_for_mutable_tensors: True 7171 variants: method 7172 dispatch: 7173 SparseCPU, SparseCUDA, SparseMeta: sparse_resize_ 7174 autogen: sparse_resize, sparse_resize.out 7175 7176- func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) 7177 use_const_ref_for_mutable_tensors: True 7178 variants: method 7179 dispatch: 7180 SparseCPU, SparseCUDA, SparseMeta: sparse_resize_and_clear_ 7181 autogen: sparse_resize_and_clear, sparse_resize_and_clear.out 7182 7183- func: sparse_mask(Tensor self, Tensor mask) -> Tensor 7184 variants: method 7185 dispatch: 7186 SparseCPU, SparseCUDA: sparse_mask 7187 SparseCsrCPU, SparseCsrCUDA: sparse_mask_sparse_compressed 7188 autogen: sparse_mask.out 7189 7190- func: _sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor 7191 variants: method 7192 dispatch: 7193 SparseCPU, SparseCUDA: sparse_mask_projection 7194 autogen: _sparse_mask_projection.out 7195 7196- func: _to_cpu(Tensor[] tensors) -> Tensor[] 7197 variants: function 7198 7199- func: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor 7200 variants: method 7201 7202# Special case of to_dense with custom derivative 7203- func: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor 7204 variants: method 7205 dispatch: 7206 SparseCPU, SparseCUDA: sparse_to_dense 7207 SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_dense 7208 MkldnnCPU: mkldnn_to_dense 7209 autogen: _to_dense.out 7210 7211- func: to_dense_backward(Tensor grad, Tensor input, bool? 
masked_grad=None) -> Tensor 7212 7213- func: sparse_dim(Tensor self) -> int 7214 variants: method 7215 dispatch: 7216 SparseCPU, SparseCUDA, SparseMeta: sparse_dim_sparse 7217 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_dim_sparse_csr 7218 CompositeExplicitAutograd: sparse_dim_default 7219 device_check: NoCheck 7220 device_guard: False 7221 7222# legacy method 7223- func: _dimI(Tensor self) -> int 7224 variants: method 7225 dispatch: 7226 SparseCPU, SparseCUDA: sparse_dim_sparse 7227 device_check: NoCheck 7228 device_guard: False 7229 7230- func: dense_dim(Tensor self) -> int 7231 variants: method 7232 dispatch: 7233 SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse 7234 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: dense_dim_sparse_csr 7235 CompositeExplicitAutograd: dense_dim_default 7236 device_check: NoCheck 7237 device_guard: False 7238 7239# legacy method 7240- func: _dimV(Tensor self) -> int 7241 variants: method 7242 dispatch: 7243 SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse 7244 device_check: NoCheck 7245 device_guard: False 7246 7247- func: _nnz(Tensor self) -> int 7248 variants: method 7249 dispatch: 7250 SparseCPU, SparseCUDA, SparseMeta: _nnz_sparse 7251 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _nnz_sparse_csr 7252 device_check: NoCheck 7253 device_guard: False 7254 7255# NOTE: [ coalesce autograd ] 7256# coalesce returns self directly for already coalesced sparse tensors. 7257# This means coalesce cannot have a derivative registered, otherwise it creates 7258# circular references in the autograd graph (see gh-52874). 7259# Instead, the derivative is registered on the slow-path "_coalesce" 7260- func: coalesce(Tensor(a) self) -> Tensor(a) 7261 variants: method 7262 7263- func: _coalesce(Tensor self) -> Tensor 7264 dispatch: 7265 SparseCPU: _coalesce_sparse_cpu 7266 SparseCUDA: _coalesce_sparse_cuda 7267 autogen: _coalesce.out 7268 7269- func: is_coalesced(Tensor self) -> bool 7270 variants: method 7271 dispatch: 7272 SparseCPU, SparseCUDA, SparseMeta: is_coalesced_sparse 7273 CompositeExplicitAutograd: is_coalesced_default 7274 device_check: NoCheck 7275 device_guard: False 7276 7277- func: _indices(Tensor(a) self) -> Tensor(a) 7278 variants: method 7279 dispatch: 7280 SparseCPU, SparseCUDA, SparseMeta: _indices_sparse 7281 device_check: NoCheck 7282 device_guard: False 7283 7284- func: _values(Tensor(a) self) -> Tensor(a) 7285 variants: method 7286 dispatch: 7287 SparseCPU, SparseCUDA, SparseMeta: _values_sparse 7288 device_check: NoCheck 7289 device_guard: False 7290 7291# This method doesn't do any check but only directly sets the flag. So it can be 7292# a bit unsafe. Similar to _indices and _values, this is useful for implementing 7293# custom sparse operations in Python/C++ extension. 7294- func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) 
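# For orientation, a minimal usage sketch of the coalesce-related entries above, written
# against the ordinary Python bindings (illustrative only; not part of the schema):
#
#   import torch
#   i = torch.tensor([[0, 0, 1], [0, 0, 2]])   # duplicate coordinate (0, 0)
#   v = torch.tensor([1.0, 2.0, 3.0])
#   s = torch.sparse_coo_tensor(i, v, (2, 3))
#   assert not s.is_coalesced()
#   c = s.coalesce()                           # duplicates are summed: value 3.0 at (0, 0)
#   assert c.is_coalesced()
#   assert c.coalesce() is c                   # already-coalesced input is returned as-is (see the note above)
#   c._indices(), c._values()                  # raw accessors used by custom sparse extensions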
7295 variants: method 7296 dispatch: 7297 SparseCPU, SparseCUDA, SparseMeta: _coalesced_sparse_ 7298 device_check: NoCheck 7299 device_guard: False 7300 autogen: _coalesced, _coalesced.out 7301 7302- func: indices(Tensor(a) self) -> Tensor(a) 7303 variants: method 7304 dispatch: 7305 SparseCPU, SparseCUDA, SparseMeta: indices_sparse 7306 CompositeExplicitAutograd: indices_default 7307 device_check: NoCheck 7308 device_guard: False 7309 7310- func: values(Tensor(a) self) -> Tensor(a) 7311 variants: method 7312 dispatch: 7313 SparseCPU, SparseCUDA, SparseMeta: values_sparse 7314 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: values_sparse_csr 7315 NestedTensorCPU, NestedTensorCUDA: values_nested 7316 CompositeExplicitAutograd: values_default 7317 device_check: NoCheck 7318 device_guard: False 7319 7320- func: crow_indices(Tensor(a) self) -> Tensor(a) 7321 variants: method 7322 dispatch: 7323 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: crow_indices_sparse_csr 7324 CompositeExplicitAutograd: crow_indices_default 7325 device_check: NoCheck 7326 device_guard: False 7327 7328- func: col_indices(Tensor(a) self) -> Tensor(a) 7329 variants: method 7330 dispatch: 7331 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: col_indices_sparse_csr 7332 CompositeExplicitAutograd: col_indices_default 7333 device_check: NoCheck 7334 device_guard: False 7335 7336- func: ccol_indices(Tensor(a) self) -> Tensor(a) 7337 variants: method 7338 dispatch: 7339 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ccol_indices_sparse_csr 7340 CompositeExplicitAutograd: ccol_indices_default 7341 device_check: NoCheck 7342 device_guard: False 7343 7344- func: row_indices(Tensor(a) self) -> Tensor(a) 7345 variants: method 7346 dispatch: 7347 SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: row_indices_sparse_csr 7348 CompositeExplicitAutograd: row_indices_default 7349 device_check: NoCheck 7350 device_guard: False 7351 7352- func: hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) 7353 dispatch: 7354 SparseCPU: hspmm_out_sparse_cpu 7355 SparseCUDA: hspmm_out_sparse_cuda 7356 7357- func: hspmm(Tensor mat1, Tensor mat2) -> Tensor 7358 dispatch: 7359 SparseCPU: hspmm_sparse_cpu 7360 SparseCUDA: hspmm_sparse_cuda 7361 7362- func: copy_sparse_to_sparse_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) 7363 device_check: NoCheck # Allows copy into different device 7364 variants: function 7365 dispatch: 7366 SparseCPU, SparseCUDA, SparseMeta: copy_sparse_ 7367 autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out 7368 7369# By adding the AutogradNestedTensor this makes this function CompositeImplicit-like for nested tensors 7370- func: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] 7371 variants: function, method 7372 dispatch: 7373 CompositeExplicitAutograd: unbind 7374 NestedTensorCPU, NestedTensorCUDA: NestedTensor_unbind 7375 7376- func: unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[] 7377 variants: function, method 7378 7379- func: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor 7380 variants: method 7381 7382# Special case of to_sparse.sparse_dim with custom derivative 7383- func: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor 7384 variants: method 7385 dispatch: 7386 CPU, CUDA: dense_to_sparse 7387 SparseCPU, SparseCUDA: sparse_coo_to_sparse 7388 SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse 7389 autogen: _to_sparse.sparse_dim_out 7390 7391- func: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? 
dense_dim=None) -> Tensor 7392 variants: method 7393 7394# Special case of to_sparse with custom derivative 7395- func: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor 7396 variants: method 7397 dispatch: 7398 CPU, CUDA: dense_to_sparse 7399 SparseCPU, SparseCUDA: sparse_coo_to_sparse 7400 SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse 7401 autogen: _to_sparse.out 7402 7403- func: to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor 7404 variants: method 7405 7406# Special case of to_sparse_csr with custom derivative 7407- func: _to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor 7408 variants: method 7409 dispatch: 7410 CPU, CUDA: dense_to_sparse_csr 7411 SparseCPU, SparseCUDA: coo_to_sparse_csr 7412 SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse_csr 7413 autogen: _to_sparse_csr.out 7414 7415- func: to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor 7416 variants: method 7417 7418# Special case of to_sparse_csc with custom derivative 7419- func: _to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor 7420 variants: method 7421 dispatch: 7422 CPU, CUDA: dense_to_sparse_csc 7423 SparseCPU, SparseCUDA: coo_to_sparse_csc 7424 SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse_csc 7425 autogen: _to_sparse_csc.out 7426 7427- func: to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor 7428 variants: method 7429 7430# Special case of to_sparse_bsr with custom derivative 7431- func: _to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor 7432 variants: method 7433 dispatch: 7434 CPU, CUDA: dense_to_sparse_bsr 7435 SparseCPU, SparseCUDA: coo_to_sparse_bsr 7436 SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse_bsr 7437 autogen: _to_sparse_bsr.out 7438 7439- func: to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor 7440 variants: method 7441 7442# Special case of to_sparse_bsc with custom derivative 7443- func: _to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor 7444 variants: method 7445 dispatch: 7446 CPU, CUDA: dense_to_sparse_bsc 7447 SparseCPU, SparseCUDA: coo_to_sparse_bsc 7448 SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_sparse_bsc 7449 autogen: _to_sparse_bsc.out 7450 7451- func: _to_sparse_semi_structured(Tensor dense) -> (Tensor, Tensor) 7452 variants: function 7453 dispatch: 7454 CUDA: _to_sparse_semi_structured 7455 7456- func: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor 7457 variants: method 7458 dispatch: 7459 CPU: dense_to_mkldnn 7460 autogen: to_mkldnn.out 7461 7462- func: mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor 7463 variants: function 7464 python_module: nn 7465 dispatch: 7466 MkldnnCPU: mkldnn_reorder_conv2d_weight 7467 autogen: mkldnn_reorder_conv2d_weight.out 7468 7469- func: mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? 
input_size=None) -> Tensor 7470 variants: function 7471 python_module: nn 7472 dispatch: 7473 MkldnnCPU: mkldnn_reorder_conv3d_weight 7474 autogen: mkldnn_reorder_conv3d_weight.out 7475 7476- func: to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor 7477 7478- func: quantize_per_tensor_dynamic(Tensor self, ScalarType dtype, bool reduce_range) -> Tensor 7479 variants: function 7480 dispatch: 7481 CPU, CUDA: quantize_per_tensor_dynamic 7482 autogen: quantize_per_tensor_dynamic.out 7483 7484- func: quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor 7485 variants: function 7486 dispatch: 7487 CPU, CUDA: quantize_per_tensor 7488 autogen: quantize_per_tensor.out 7489 7490- func: quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor 7491 variants: function 7492 dispatch: 7493 CPU, CUDA: quantize_per_tensor_tensor_qparams 7494 autogen: quantize_per_tensor.tensor_qparams_out 7495 7496- func: quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[] 7497 variants: function 7498 dispatch: 7499 CPU: quantize_per_tensor_list_cpu 7500 autogen: quantize_per_tensor.tensors_out 7501 7502- func: quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor 7503 variants: function 7504 dispatch: 7505 CPU, CUDA: quantize_per_channel 7506 autogen: quantize_per_channel.out 7507 7508- func: dequantize.self(Tensor self) -> Tensor 7509 variants: function, method 7510 dispatch: 7511 CPU, CUDA: dequantize_cpu_or_cuda 7512 QuantizedCPU, QuantizedCUDA: dequantize_quantized 7513 autogen: dequantize.self_out 7514 7515- func: dequantize.tensors(Tensor[] tensors) -> Tensor[] 7516 variants: function 7517 dispatch: 7518 QuantizedCPU: dequantize_tensors_quantized_cpu 7519 autogen: dequantize.tensors_out 7520 7521- func: q_scale(Tensor self) -> float 7522 variants: function, method 7523 dispatch: 7524 QuantizedCPU, QuantizedCUDA: q_scale_quant 7525 7526- func: q_zero_point(Tensor self) -> int 7527 variants: function, method 7528 dispatch: 7529 QuantizedCPU, QuantizedCUDA: q_zero_point_quant 7530 7531- func: q_per_channel_scales(Tensor self) -> Tensor 7532 variants: function, method 7533 dispatch: 7534 QuantizedCPU, QuantizedCUDA: q_per_channel_scales 7535 autogen: q_per_channel_scales.out 7536 7537- func: q_per_channel_zero_points(Tensor self) -> Tensor 7538 variants: function, method 7539 dispatch: 7540 QuantizedCPU, QuantizedCUDA: q_per_channel_zero_points 7541 autogen: q_per_channel_zero_points.out 7542 7543- func: q_per_channel_axis(Tensor self) -> int 7544 variants: function, method 7545 dispatch: 7546 QuantizedCPU, QuantizedCUDA: q_per_channel_axis 7547 7548- func: int_repr(Tensor self) -> Tensor 7549 device_check: NoCheck # TensorIterator 7550 variants: function, method 7551 dispatch: 7552 QuantizedCPU: int_repr_quantized_cpu 7553 QuantizedCUDA: int_repr_quantized_cuda 7554 autogen: int_repr.out 7555 7556- func: _make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor 7557 dispatch: 7558 CPU: make_per_tensor_quantized_tensor_cpu 7559 CUDA: make_per_tensor_quantized_tensor_cuda 7560 autogen: _make_per_tensor_quantized_tensor.out 7561 7562- func: _make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor 7563 dispatch: 7564 CPU: make_per_channel_quantized_tensor_cpu 7565 CUDA: make_per_channel_quantized_tensor_cuda 7566 autogen: 
_make_per_channel_quantized_tensor.out 7567 7568- func: qscheme(Tensor self) -> QScheme 7569 variants: method 7570 dispatch: 7571 QuantizedCPU, QuantizedCUDA: qscheme_quant 7572 7573- func: fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor 7574 device_check: NoCheck # TensorIterator 7575 variants: function 7576 7577- func: fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor 7578 device_check: NoCheck # TensorIterator 7579 variants: function 7580 7581- func: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask) 7582 variants: function 7583 dispatch: 7584 CPU, CUDA: fake_quantize_per_tensor_affine_cachemask 7585 autogen: fake_quantize_per_tensor_affine_cachemask.out 7586 7587- func: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask) 7588 variants: function 7589 dispatch: 7590 CPU, CUDA: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams 7591 autogen: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out 7592 7593- func: fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor 7594 variants: function 7595 7596- func: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor 7597 variants: function 7598 dispatch: 7599 CPU, CUDA: _fake_quantize_learnable_per_tensor_affine 7600 autogen: _fake_quantize_learnable_per_tensor_affine.out 7601 7602- func: _fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor) 7603 variants: function 7604 dispatch: 7605 CPU, CUDA: _fake_quantize_learnable_per_tensor_affine_backward 7606 7607- func: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor 7608 device_check: NoCheck # TensorIterator 7609 variants: function 7610 7611- func: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask) 7612 variants: function 7613 dispatch: 7614 CPU, CUDA: fake_quantize_per_channel_affine_cachemask 7615 autogen: fake_quantize_per_channel_affine_cachemask.out 7616 7617- func: fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor 7618 variants: function 7619 7620- func: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor 7621 variants: function 7622 dispatch: 7623 CPU, CUDA: _fake_quantize_learnable_per_channel_affine 7624 autogen: _fake_quantize_learnable_per_channel_affine.out 7625 7626- func: _fake_quantize_learnable_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor) 7627 variants: function 7628 dispatch: 7629 CPU, CUDA: _fake_quantize_learnable_per_channel_affine_backward 7630 7631- func: fused_moving_avg_obs_fake_quant(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) 
running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> Tensor 7632 variants: function 7633 7634- func: _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) 7635 dispatch: 7636 CPU: fused_moving_avg_obs_fake_quant_cpu 7637 CUDA: fused_moving_avg_obs_fake_quant_cuda 7638 autogen: _fused_moving_avg_obs_fq_helper_functional, _fused_moving_avg_obs_fq_helper.out 7639 7640- func: _choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int) 7641 variants: function 7642 7643- func: _saturate_weight_to_fp16(Tensor weight) -> Tensor 7644 variants: function 7645 7646- func: choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor) 7647 variants: function 7648 7649- func: _autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a) 7650 variants: method 7651 device_guard: False 7652 7653- func: _autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a) 7654 variants: method 7655 device_guard: False 7656 7657- func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor 7658 device_check: NoCheck 7659 device_guard: False 7660 dispatch: 7661 CompositeExplicitAutograd: _to_copy 7662 NestedTensorCPU, NestedTensorCUDA: _to_copy_nested 7663 autogen: _to_copy.out 7664 tags: core 7665 7666# to(Device) must not exist because all constructors of Device also work for 7667# TensorOptions. Otherwise, an ambiguity error is thrown. 7668# See NOTE [ TensorOptions Constructors ]. 7669- func: to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) 7670 variants: method 7671 device_check: NoCheck 7672 device_guard: False 7673 7674- func: to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) 7675 variants: method 7676 device_check: NoCheck 7677 device_guard: False 7678 7679- func: to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) 7680 variants: method 7681 device_check: NoCheck 7682 device_guard: False 7683 7684- func: to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) 7685 variants: method 7686 device_check: NoCheck 7687 device_guard: False 7688 7689- func: meshgrid(Tensor[] tensors) -> Tensor[] 7690 7691# TODO: Two weeks after this lands, combine these two overloads, 7692# making "indexing" optional. These are temporarily distinct for 7693# forward-compatibility reasons.
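# A small sketch of how the two meshgrid overloads are typically called from Python
# (plain torch calls, shown here only for illustration):
#
#   import torch
#   a, b = torch.arange(3), torch.arange(4)
#   xs, ys = torch.meshgrid(a, b, indexing='ij')   # matrix indexing; both outputs have shape (3, 4)
#   xv, yv = torch.meshgrid(a, b, indexing='xy')   # Cartesian indexing; first two dims swapped, shape (4, 3)
#   assert xs.shape == (3, 4) and xv.shape == (4, 3)
#
# The overload without `indexing` behaves like the historical 'ij'-style default and is the
# one the TODO above intends to fold into a single schema.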
7694- func: meshgrid.indexing(Tensor[] tensors, *, str indexing) -> Tensor[] 7695 7696- func: cartesian_prod(Tensor[] tensors) -> Tensor 7697 variants: function 7698 7699- func: combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor 7700 variants: function 7701 7702- func: item(Tensor self) -> Scalar 7703 tags: data_dependent_output 7704 variants: method 7705 7706- func: result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType 7707 variants: function 7708 7709- func: result_type.Scalar(Tensor tensor, Scalar other) -> ScalarType 7710 variants: function 7711 7712- func: result_type.Scalar_Tensor(Scalar scalar, Tensor tensor) -> ScalarType 7713 variants: function 7714 7715- func: result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType 7716 7717- func: can_cast(ScalarType from_, ScalarType to) -> bool 7718 variants: function 7719 7720- func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType 7721 variants: function 7722 7723# NB: Does NOT check precondition that numel == 1 7724- func: _local_scalar_dense(Tensor self) -> Scalar 7725 tags: [core, data_dependent_output] 7726 dispatch: 7727 CPU: _local_scalar_dense_cpu 7728 CUDA: _local_scalar_dense_cuda 7729 MPS: _local_scalar_dense_mps 7730 variants: function 7731 7732# MPS LSTM implementation 7733 7734- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) 7735 dispatch: 7736 MPS: _lstm_mps 7737 autogen: _lstm_mps.out 7738 tags: nondeterministic_seeded 7739 7740- func: lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[]) 7741 dispatch: 7742 MPS: lstm_mps_backward 7743 autogen: lstm_mps_backward.out 7744 7745 7746# Fused RNN kernels 7747- func: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) 7748 dispatch: 7749 CUDA: _thnn_fused_lstm_cell_cuda 7750 autogen: _thnn_fused_lstm_cell.out 7751 7752# NB: The composite version of this function below is a simple wrapper that duplicates some of the outputs. 7753# It is necessary to avoid triggering TensorImpl use count checks in debug mode 7754# NB: this function is NOT differentiable 7755- func: _thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor) 7756 dispatch: 7757 CUDA: _thnn_fused_lstm_cell_backward_impl_cuda 7758 autogen: _thnn_fused_lstm_cell_backward_impl.out 7759 7760- func: _thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) 7761 7762- func: _thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor) 7763 7764- func: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor?
hidden_bias=None) -> (Tensor, Tensor) 7765 dispatch: 7766 CUDA: _thnn_fused_gru_cell_cuda 7767 autogen: _thnn_fused_gru_cell.out 7768 7769- func: _thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) 7770 dispatch: 7771 CUDA: _thnn_fused_gru_cell_backward_cuda 7772 autogen: _thnn_fused_gru_cell_backward.out 7773 7774- func: _thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) 7775 7776# RNN cells and layers 7777- func: lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor) 7778 tags: nondeterministic_seeded 7779 7780- func: lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor) 7781 tags: nondeterministic_seeded 7782 7783- func: gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) 7784 tags: nondeterministic_seeded 7785 7786- func: gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) 7787 tags: nondeterministic_seeded 7788 7789- func: rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) 7790 tags: nondeterministic_seeded 7791 7792- func: rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) 7793 tags: nondeterministic_seeded 7794 7795- func: rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) 7796 tags: nondeterministic_seeded 7797 7798- func: rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) 7799 tags: nondeterministic_seeded 7800 7801- func: lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor) 7802 7803- func: gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor 7804 7805- func: rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor 7806 7807- func: rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor 7808 7809# Quantized RNN layer registration has been moved to C10 dispatch in `RNN.cpp` 7810 7811# Quantized RNN layers 7812# - func: quantized_lstm(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor) 7813 7814 7815# - func: quantized_lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? 
dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor) 7816 7817 7818# Quantized GRU layers 7819 7820# - func: quantized_gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) 7821# 7822 7823# - func: quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) 7824# 7825 7826# Quantized RNN cells 7827- func: quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor) 7828 7829- func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor 7830 7831- func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor 7832 7833- func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor 7834 7835# PackedSequence utilities 7836- func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) 7837 dispatch: 7838 CompositeExplicitAutograd: _pack_padded_sequence 7839 autogen: _pack_padded_sequence.out 7840 7841- func: _pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor 7842 dispatch: 7843 CompositeImplicitAutograd: _pack_padded_sequence_backward_symint 7844 7845- func: _pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor) 7846 7847# wrappers for legacy TH methods 7848 7849- func: set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!) 7850 variants: method 7851 device_check: NoCheck 7852 device_guard: False 7853 dispatch: 7854 CPU, CUDA, Meta, MPS: set_ 7855 autogen: set.source_Storage, set.source_Storage_out 7856 tags: inplace_view 7857 7858- func: set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) 7859 variants: method 7860 device_check: NoCheck 7861 device_guard: False 7862 dispatch: 7863 CPU: set_storage_cpu_ 7864 Meta: set_storage_meta__symint 7865 CUDA: set_storage_cuda_ 7866 MPS: set_storage_mps_ 7867 QuantizedCPU, QuantizedCUDA: set_storage_quantized_ 7868 autogen: set.source_Storage_storage_offset, set.source_Storage_storage_offset_out 7869 tags: inplace_view 7870 7871- func: set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) 7872 variants: method 7873 device_check: NoCheck 7874 device_guard: False 7875 dispatch: 7876 CompositeImplicitAutograd: set__symint 7877 tags: inplace_view 7878 7879- func: set_.source_Tensor(Tensor(a!) 
self, Tensor source) -> Tensor(a!) 7880 variants: method 7881 device_check: NoCheck 7882 device_guard: False 7883 dispatch: 7884 CPU, CUDA, Meta, MPS: set_tensor_ 7885 autogen: set.source_Tensor, set.source_Tensor_out 7886 tags: inplace_view 7887 7888- func: set_(Tensor(a!) self) -> Tensor(a!) 7889 variants: method 7890 dispatch: 7891 CPU: set_cpu_ 7892 CUDA: set_cuda_ 7893 Meta: set_meta_ 7894 MPS: set_mps_ 7895 autogen: set, set.out 7896 tags: inplace_view 7897 7898# Not making it CompositeImplicitAutograd because lift 7899# should be a primitive w.r.t. functorch 7900 7901# TODO: this should have a view annotation 7902# TODO: shouldn't be a method 7903- func: lift(Tensor self) -> Tensor 7904 dispatch: 7905 CompositeExplicitAutograd: lift 7906 autogen: lift.out 7907 7908# lift_fresh is called with an argument that is guaranteed to be 7909# fresh (i.e., newly allocated). This is ONLY called from a 7910# torch.tensor call; if you FX trace a lift_fresh, you are obligated 7911# to convert this into a lift_fresh_copy (because FX will violate the 7912# freshness invariant when tracing). 7913- func: lift_fresh(Tensor(a) self) -> Tensor(a) 7914 dispatch: 7915 CompositeExplicitAutograd: lift_fresh 7916 7917# Like lift, but it clones the input. 7918- func: lift_fresh_copy(Tensor self) -> Tensor 7919 tags: view_copy 7920 dispatch: 7921 CompositeExplicitAutogradNonFunctional: lift_fresh_copy 7922 autogen: lift_fresh_copy.out 7923 7924- func: is_set_to(Tensor self, Tensor tensor) -> bool 7925 variants: method 7926 device_check: NoCheck 7927 device_guard: False 7928 dispatch: 7929 CPU, CUDA, MPS: is_set_to 7930 7931- func: masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!) 7932 device_check: NoCheck # TensorIterator 7933 variants: method 7934 dispatch: 7935 CPU: masked_fill__cpu 7936 CUDA: masked_fill__cuda 7937 QuantizedCPU: masked_fill__quantized_cpu 7938 QuantizedCUDA: masked_fill__quantized_cuda 7939 MPS: masked_fill__mps 7940 autogen: masked_fill.Scalar_out 7941 7942- func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor 7943 device_check: NoCheck # TensorIterator 7944 variants: function, method 7945 dispatch: 7946 CompositeExplicitAutograd: masked_fill 7947 NestedTensorCPU, NestedTensorCUDA: NestedTensor_masked_fill 7948 tags: pointwise 7949 7950- func: masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!) 7951 device_check: NoCheck # TensorIterator 7952 variants: method 7953 dispatch: 7954 CPU: masked_fill__cpu 7955 CUDA: masked_fill__cuda 7956 QuantizedCPU: masked_fill__quantized_cpu 7957 QuantizedCUDA: masked_fill__quantized_cuda 7958 MPS: masked_fill__mps 7959 autogen: masked_fill.Tensor_out 7960 7961- func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor 7962 device_check: NoCheck # TensorIterator 7963 variants: function, method 7964 dispatch: 7965 CompositeExplicitAutograd: masked_fill 7966 7967- func: masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!) 
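# A brief illustration of the masked_fill / masked_scatter family declared around here
# (ordinary Python usage; not part of the schema):
#
#   import torch
#   mask = torch.tensor([[True, False, True], [False, True, False]])
#   x = torch.zeros(2, 3)
#   x.masked_fill_(mask, 7.0)                          # Scalar overload: writes 7.0 wherever mask is True
#   y = torch.zeros(2, 3).masked_scatter(mask, torch.arange(3.0))
#   # masked_scatter consumes `source` elements in row-major order of the True positions:
#   assert y[0, 0] == 0.0 and y[0, 2] == 1.0 and y[1, 1] == 2.0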
7968 variants: method 7969 dispatch: 7970 CPU: masked_scatter__cpu 7971 CUDA: masked_scatter__cuda 7972 MPS: masked_scatter__mps 7973 autogen: masked_scatter.out 7974 7975- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor 7976 variants: function, method 7977 dispatch: 7978 CompositeExplicitAutograd: masked_scatter 7979 7980- func: masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor 7981 dispatch: 7982 CompositeExplicitAutograd: masked_scatter_backward_symint 7983 7984- func: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor 7985 dispatch: 7986 CUDA: masked_softmax_cuda 7987 CPU: masked_softmax_cpu 7988 autogen: _masked_softmax.out 7989 7990- func: _masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? dim=None) -> Tensor 7991 dispatch: 7992 CUDA: masked_softmax_backward_cuda 7993 CPU: masked_softmax_backward_cpu 7994 autogen: _masked_softmax_backward.out 7995 7996- func: view(Tensor(a) self, SymInt[] size) -> Tensor(a) 7997 variants: method 7998 device_check: NoCheck 7999 device_guard: False 8000 dispatch: 8001 ZeroTensor, Meta, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS: view 8002 MkldnnCPU: mkldnn_view 8003 NestedTensorCPU, NestedTensorCUDA: view_nested 8004 tags: core 8005 8006# Warning: If you want to change the name or overload name of this 8007# operator, you might also want to change the `isBlockListedSchema` 8008# function in `torch/csrc/jit/frontend/schema_matching.cpp`. 8009# The name and overload name of this operator are hardcoded in that 8010# function in order to work around a bug: 8011# https://github.com/pytorch/pytorch/issues/47964 8012- func: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) 8013 variants: method 8014 device_check: NoCheck 8015 device_guard: False 8016 dispatch: 8017 CompositeExplicitAutograd: view_dtype 8018 8019- func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) 8020 variants: method 8021 dispatch: 8022 CPU, CUDA: put_ 8023 autogen: put.out 8024 8025- func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor 8026 variants: function, method 8027 dispatch: 8028 CompositeExplicitAutograd: put 8029 8030- func: index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 8031 structured: True 8032 variants: function 8033 precomputed: 8034 - dim -> int dim 8035 dispatch: 8036 CPU: index_add_cpu_out 8037 CUDA: index_add_cuda_out 8038 MPS: index_add_mps_out 8039 8040- func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!) 8041 structured_delegate: index_add.out 8042 variants: method 8043 8044- func: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor 8045 structured_delegate: index_add.out 8046 variants: function, method 8047 8048- func: index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor 8049 variants: function, method 8050 8051- func: index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) 8052 structured: True 8053 variants: function 8054 precomputed: 8055 - dim -> int dim 8056 dispatch: 8057 CPU: index_reduce_cpu_out 8058 CUDA: index_reduce_cuda_out 8059 8060- func: index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!)
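# index_add / index_reduce sketch (illustrative Python, not part of the schema): `index`
# selects positions along `dim`, and rows of `source` are accumulated into them.
#
#   import torch
#   x = torch.zeros(3, 2)
#   src = torch.tensor([[2., 2.], [3., 3.], [4., 4.]])
#   x.index_add_(0, torch.tensor([0, 2, 0]), src)      # row 0 <- 2 + 4, row 2 <- 3
#   assert x[0, 0] == 6.0 and x[2, 0] == 3.0 and x[1, 0] == 0.0
#   r = torch.ones(3, 2).index_reduce(0, torch.tensor([0, 2, 0]), src, 'amax')
#   # 'amax' keeps the elementwise maximum of the gathered rows (and of self, since include_self=True)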
8061 structured_delegate: index_reduce.out 8062 variants: method 8063 8064- func: index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor 8065 structured_delegate: index_reduce.out 8066 variants: function, method 8067 8068- func: index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) 8069 device_check: NoCheck # TensorIterator 8070 variants: method 8071 dispatch: 8072 CPU: index_fill_ 8073 CUDA: index_fill_ 8074 MPS: index_fill_mps_ 8075 autogen: index_fill.int_Scalar_out 8076 8077- func: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor 8078 device_check: NoCheck # TensorIterator 8079 variants: function, method 8080 dispatch: 8081 CompositeExplicitAutograd: index_fill 8082 8083- func: index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!) 8084 device_check: NoCheck # TensorIterator 8085 variants: method 8086 dispatch: 8087 CPU, CUDA: index_fill_ 8088 MPS: index_fill_mps_ 8089 autogen: index_fill.int_Tensor_out 8090 8091- func: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor 8092 device_check: NoCheck # TensorIterator 8093 variants: function, method 8094 dispatch: 8095 CompositeExplicitAutograd: index_fill 8096 8097- func: index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!) 8098 device_check: NoCheck # TensorIterator 8099 variants: method 8100 8101- func: index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!) 8102 device_check: NoCheck # TensorIterator 8103 variants: method 8104 8105- func: index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor 8106 device_check: NoCheck # TensorIterator 8107 variants: function, method 8108 8109- func: index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor 8110 device_check: NoCheck # TensorIterator 8111 variants: function, method 8112 8113- func: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor 8114 structured_delegate: scatter.src_out 8115 variants: function, method 8116 tags: core 8117 8118- func: scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) 8119 structured_delegate: scatter.src_out 8120 variants: method 8121 8122- func: scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) 8123 structured: True 8124 variants: function 8125 dispatch: 8126 CPU, CUDA: scatter_src_out 8127 MPS: scatter_src_out_mps 8128 8129- func: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor 8130 structured_delegate: scatter.value_out 8131 variants: function, method 8132 tags: core 8133 8134- func: scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) 8135 structured_delegate: scatter.value_out 8136 variants: method 8137 8138- func: scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) 8139 structured: True 8140 variants: function 8141 dispatch: 8142 CPU, CUDA: scatter_value_out 8143 MPS: scatter_value_out_mps 8144 8145- func: scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor 8146 structured_delegate: scatter.reduce_out 8147 variants: function, method 8148 8149- func: scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!) 
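# scatter / scatter_add / scatter_reduce sketch (illustrative Python, not part of the
# schema). For dim=0 the addressing rule is out[index[i][j]][j] = src[i][j]:
#
#   import torch
#   src = torch.arange(1., 7.).reshape(2, 3)           # [[1, 2, 3], [4, 5, 6]]
#   index = torch.tensor([[0, 1, 2], [2, 0, 1]])
#   out = torch.zeros(3, 3).scatter(0, index, src)     # plain overwrite at the addressed slots
#   acc = torch.zeros(3, 3).scatter_add(0, torch.zeros(2, 3, dtype=torch.long), src)
#   assert acc[0, 0] == 5.0                            # duplicate targets accumulate: 1 + 4
#   red = torch.zeros(3, 3).scatter_reduce(0, index, src, reduce='amax', include_self=False)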
8150 structured_delegate: scatter.reduce_out 8151 variants: method 8152 8153- func: scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!) 8154 structured: True 8155 variants: function 8156 dispatch: 8157 CPU, CUDA: scatter_reduce_out 8158 MPS: scatter_reduce_out_mps 8159 8160- func: scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor 8161 structured_delegate: scatter.value_reduce_out 8162 variants: function, method 8163 8164- func: scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!) 8165 structured_delegate: scatter.value_reduce_out 8166 variants: method 8167 8168- func: scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!) 8169 structured: True 8170 variants: function 8171 dispatch: 8172 CPU, CUDA: scatter_value_reduce_out 8173 MPS: scatter_value_reduce_out_mps 8174 8175- func: scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor 8176 variants: function, method 8177 8178- func: scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor 8179 variants: function, method 8180 8181- func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor 8182 structured_delegate: scatter_add.out 8183 variants: function, method 8184 tags: core 8185 8186- func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) 8187 structured_delegate: scatter_add.out 8188 variants: method 8189 8190- func: scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) 8191 structured: True 8192 variants: function 8193 dispatch: 8194 CPU, CUDA: scatter_add 8195 MPS: scatter_add_mps_out 8196 8197- func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor 8198 variants: function, method 8199 8200- func: scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor 8201 structured_delegate: scatter_reduce.two_out 8202 variants: function, method 8203 tags: core 8204 8205- func: scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!) 8206 structured_delegate: scatter_reduce.two_out 8207 variants: method 8208 8209- func: scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) 8210 structured: True 8211 variants: function 8212 dispatch: 8213 CPU, CUDA: scatter_reduce_two 8214 8215- func: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8216 structured_delegate: eq.Scalar_out 8217 device_check: NoCheck # TensorIterator 8218 variants: method 8219 8220- func: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 8221 structured_delegate: eq.Tensor_out 8222 device_check: NoCheck # TensorIterator 8223 variants: method 8224 8225- func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 8226 device_check: NoCheck # TensorIterator 8227 structured: True 8228 structured_inherits: TensorIteratorBase 8229 variants: function 8230 dispatch: 8231 CPU, CUDA: bitwise_and_out 8232 MPS: bitwise_and_out_mps 8233 tags: pointwise 8234 8235- func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
8236 device_check: NoCheck # TensorIterator 8237 variants: function 8238 dispatch: 8239 CompositeExplicitAutograd: bitwise_and_out 8240 tags: pointwise 8241 8242- func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor 8243 device_check: NoCheck # TensorIterator 8244 variants: method, function 8245 dispatch: 8246 CompositeExplicitAutograd: bitwise_and 8247 tags: [core, pointwise] 8248 8249- func: bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor 8250 device_check: NoCheck # TensorIterator 8251 variants: function 8252 dispatch: 8253 CompositeExplicitAutograd: bitwise_and 8254 autogen: bitwise_and.Scalar_Tensor_out 8255 tags: pointwise 8256 8257- func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor 8258 device_check: NoCheck # TensorIterator 8259 variants: method, function 8260 structured_delegate: bitwise_and.Tensor_out 8261 tags: [core, pointwise] 8262 8263- func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8264 device_check: NoCheck # TensorIterator 8265 variants: method 8266 dispatch: 8267 CompositeExplicitAutograd: bitwise_and_ 8268 tags: pointwise 8269 8270- func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 8271 device_check: NoCheck # TensorIterator 8272 variants: method 8273 structured_delegate: bitwise_and.Tensor_out 8274 tags: pointwise 8275 8276- func: __and__.Scalar(Tensor self, Scalar other) -> Tensor 8277 device_check: NoCheck # TensorIterator 8278 variants: method, function 8279 8280- func: __and__.Tensor(Tensor self, Tensor other) -> Tensor 8281 device_check: NoCheck # TensorIterator 8282 variants: method, function 8283 8284- func: __iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8285 device_check: NoCheck # TensorIterator 8286 variants: method 8287 8288- func: __iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 8289 device_check: NoCheck # TensorIterator 8290 variants: method 8291 8292- func: bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 8293 device_check: NoCheck # TensorIterator 8294 structured: True 8295 structured_inherits: TensorIteratorBase 8296 variants: function 8297 dispatch: 8298 CPU, CUDA: bitwise_or_out 8299 MPS: bitwise_or_out_mps 8300 tags: pointwise 8301 8302- func: bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 8303 device_check: NoCheck # TensorIterator 8304 variants: function 8305 dispatch: 8306 CompositeExplicitAutograd: bitwise_or_out 8307 tags: pointwise 8308 8309- func: bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor 8310 device_check: NoCheck # TensorIterator 8311 variants: method, function 8312 dispatch: 8313 CompositeExplicitAutograd: bitwise_or 8314 tags: [core, pointwise] 8315 8316- func: bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor 8317 device_check: NoCheck # TensorIterator 8318 variants: function 8319 dispatch: 8320 CompositeExplicitAutograd: bitwise_or 8321 autogen: bitwise_or.Scalar_Tensor_out 8322 tags: pointwise 8323 8324- func: bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor 8325 device_check: NoCheck # TensorIterator 8326 variants: method, function 8327 structured_delegate: bitwise_or.Tensor_out 8328 tags: [core, pointwise] 8329 8330- func: bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8331 device_check: NoCheck # TensorIterator 8332 variants: method 8333 dispatch: 8334 CompositeExplicitAutograd: bitwise_or_ 8335 tags: pointwise 8336 8337- func: bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
8338 device_check: NoCheck # TensorIterator 8339 variants: method 8340 structured_delegate: bitwise_or.Tensor_out 8341 tags: pointwise 8342 8343- func: __or__.Scalar(Tensor self, Scalar other) -> Tensor 8344 device_check: NoCheck # TensorIterator 8345 variants: method, function 8346 8347- func: __or__.Tensor(Tensor self, Tensor other) -> Tensor 8348 device_check: NoCheck # TensorIterator 8349 variants: method, function 8350 8351- func: __ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8352 device_check: NoCheck # TensorIterator 8353 variants: method 8354 8355- func: __ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 8356 device_check: NoCheck # TensorIterator 8357 variants: method 8358 8359- func: bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 8360 device_check: NoCheck # TensorIterator 8361 structured: True 8362 structured_inherits: TensorIteratorBase 8363 variants: function 8364 dispatch: 8365 CPU, CUDA: bitwise_xor_out 8366 MPS: bitwise_xor_out_mps 8367 tags: pointwise 8368 8369- func: bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 8370 device_check: NoCheck # TensorIterator 8371 variants: function 8372 dispatch: 8373 CompositeExplicitAutograd: bitwise_xor_out 8374 tags: pointwise 8375 8376- func: bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor 8377 device_check: NoCheck # TensorIterator 8378 variants: method, function 8379 dispatch: 8380 CompositeExplicitAutograd: bitwise_xor 8381 tags: [core, pointwise] 8382 8383- func: bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor 8384 device_check: NoCheck # TensorIterator 8385 variants: function 8386 dispatch: 8387 CompositeExplicitAutograd: bitwise_xor 8388 autogen: bitwise_xor.Scalar_Tensor_out 8389 tags: pointwise 8390 8391- func: bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor 8392 device_check: NoCheck # TensorIterator 8393 variants: method, function 8394 structured_delegate: bitwise_xor.Tensor_out 8395 tags: [core, pointwise] 8396 8397- func: bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8398 device_check: NoCheck # TensorIterator 8399 variants: method 8400 dispatch: 8401 CompositeExplicitAutograd: bitwise_xor_ 8402 tags: pointwise 8403 8404- func: bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 8405 device_check: NoCheck # TensorIterator 8406 variants: method 8407 structured_delegate: bitwise_xor.Tensor_out 8408 tags: pointwise 8409 8410- func: __xor__.Scalar(Tensor self, Scalar other) -> Tensor 8411 device_check: NoCheck # TensorIterator 8412 variants: method, function 8413 tags: pointwise 8414 8415- func: __xor__.Tensor(Tensor self, Tensor other) -> Tensor 8416 device_check: NoCheck # TensorIterator 8417 variants: method, function 8418 tags: pointwise 8419 8420- func: __ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8421 device_check: NoCheck # TensorIterator 8422 variants: method 8423 tags: pointwise 8424 8425- func: __ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 
8426 device_check: NoCheck # TensorIterator 8427 variants: method 8428 tags: pointwise 8429 8430- func: __lshift__.Scalar(Tensor self, Scalar other) -> Tensor 8431 device_check: NoCheck # TensorIterator 8432 variants: method, function 8433 dispatch: 8434 CPU, CUDA: __lshift__ 8435 tags: pointwise 8436 8437- func: __lshift__.Tensor(Tensor self, Tensor other) -> Tensor 8438 device_check: NoCheck # TensorIterator 8439 variants: method, function 8440 dispatch: 8441 CPU, CUDA: __lshift__ 8442 tags: pointwise 8443 8444- func: __ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8445 device_check: NoCheck # TensorIterator 8446 variants: method 8447 dispatch: 8448 CPU, CUDA: __ilshift__ 8449 autogen: __lshift__.Scalar_out 8450 tags: pointwise 8451 8452- func: __ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 8453 device_check: NoCheck # TensorIterator 8454 variants: method 8455 dispatch: 8456 CPU, CUDA: __ilshift__ 8457 autogen: __lshift__.Tensor_out 8458 tags: pointwise 8459 8460- func: bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor 8461 device_check: NoCheck # TensorIterator 8462 variants: function, method 8463 structured_delegate: bitwise_left_shift.Tensor_out 8464 tags: pointwise 8465 8466- func: bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 8467 device_check: NoCheck # TensorIterator 8468 variants: method 8469 structured_delegate: bitwise_left_shift.Tensor_out 8470 tags: pointwise 8471 8472- func: bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 8473 device_check: NoCheck # TensorIterator 8474 structured: True 8475 structured_inherits: TensorIteratorBase 8476 dispatch: 8477 CPU, CUDA: bitwise_left_shift_out 8478 tags: pointwise 8479 8480- func: bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor 8481 device_check: NoCheck # TensorIterator 8482 variants: method, function 8483 dispatch: 8484 CompositeExplicitAutograd: bitwise_left_shift 8485 tags: pointwise 8486 8487- func: bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8488 device_check: NoCheck # TensorIterator 8489 variants: method 8490 dispatch: 8491 CompositeExplicitAutograd: bitwise_left_shift_ 8492 tags: pointwise 8493 8494- func: bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 8495 device_check: NoCheck # TensorIterator 8496 variants: function 8497 dispatch: 8498 CompositeExplicitAutograd: bitwise_left_shift_out 8499 tags: pointwise 8500 8501- func: bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor 8502 device_check: NoCheck # TensorIterator 8503 variants: function 8504 dispatch: 8505 CompositeExplicitAutograd: bitwise_left_shift 8506 autogen: bitwise_left_shift.Scalar_Tensor_out 8507 tags: pointwise 8508 8509- func: __rshift__.Scalar(Tensor self, Scalar other) -> Tensor 8510 device_check: NoCheck # TensorIterator 8511 variants: method, function 8512 dispatch: 8513 CPU, CUDA: __rshift__ 8514 tags: pointwise 8515 8516- func: __rshift__.Tensor(Tensor self, Tensor other) -> Tensor 8517 device_check: NoCheck # TensorIterator 8518 variants: method, function 8519 dispatch: 8520 CPU, CUDA: __rshift__ 8521 tags: pointwise 8522 8523- func: __irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8524 device_check: NoCheck # TensorIterator 8525 variants: method 8526 dispatch: 8527 CPU, CUDA: __irshift__ 8528 autogen: __rshift__.Scalar_out 8529 8530- func: __irshift__.Tensor(Tensor(a!) 
self, Tensor other) -> Tensor(a!) 8531 device_check: NoCheck # TensorIterator 8532 variants: method 8533 dispatch: 8534 CPU, CUDA: __irshift__ 8535 autogen: __rshift__.Tensor_out 8536 8537- func: bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor 8538 device_check: NoCheck # TensorIterator 8539 variants: function, method 8540 structured_delegate: bitwise_right_shift.Tensor_out 8541 tags: pointwise 8542 8543- func: bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) 8544 device_check: NoCheck # TensorIterator 8545 variants: method 8546 structured_delegate: bitwise_right_shift.Tensor_out 8547 tags: pointwise 8548 8549- func: bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 8550 device_check: NoCheck # TensorIterator 8551 structured: True 8552 structured_inherits: TensorIteratorBase 8553 dispatch: 8554 CPU, CUDA: bitwise_right_shift_out 8555 tags: pointwise 8556 8557- func: bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor 8558 device_check: NoCheck # TensorIterator 8559 variants: method, function 8560 dispatch: 8561 CompositeExplicitAutograd: bitwise_right_shift 8562 tags: pointwise 8563 8564- func: bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) 8565 device_check: NoCheck # TensorIterator 8566 variants: method 8567 dispatch: 8568 CompositeExplicitAutograd: bitwise_right_shift_ 8569 tags: pointwise 8570 8571- func: bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 8572 device_check: NoCheck # TensorIterator 8573 variants: function 8574 dispatch: 8575 CompositeExplicitAutograd: bitwise_right_shift_out 8576 tags: pointwise 8577 8578- func: bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor 8579 device_check: NoCheck # TensorIterator 8580 variants: function 8581 dispatch: 8582 CompositeExplicitAutograd: bitwise_right_shift 8583 autogen: bitwise_right_shift.Scalar_Tensor_out 8584 tags: pointwise 8585 8586- func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) 8587 structured_delegate: tril.out 8588 variants: method 8589 8590- func: triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) 8591 structured_delegate: triu.out 8592 variants: method 8593 8594- func: digamma_(Tensor(a!) self) -> Tensor(a!) 8595 device_check: NoCheck # TensorIterator 8596 structured_delegate: digamma.out 8597 variants: method 8598 tags: pointwise 8599 8600- func: lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!) 8601 device_check: NoCheck # TensorIterator 8602 variants: method 8603 structured_delegate: lerp.Scalar_out 8604 tags: pointwise 8605 8606- func: lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!) 8607 device_check: NoCheck # TensorIterator 8608 variants: method 8609 structured_delegate: lerp.Tensor_out 8610 tags: pointwise 8611 8612- func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) 8613 variants: method 8614 dispatch: 8615 CPU, CUDA: addbmm_ 8616 MPS: addbmm_mps_ 8617 8618- func: addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) 8619 dispatch: 8620 CPU, CUDA: addbmm_out 8621 MPS: addbmm_out_mps 8622 8623- func: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor 8624 variants: method, function 8625 dispatch: 8626 CPU, CUDA: addbmm 8627 MPS: addbmm_mps 8628 8629- func: random_.from(Tensor(a!) 

- func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  tags: nondeterministic_seeded
  dispatch:
    CPU, CUDA: random_
    Meta: random_meta_
    MPS: random_mps_
  autogen: random.from, random.from_out

- func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  tags: nondeterministic_seeded
  variants: method
  dispatch:
    CPU, CUDA: random_
    Meta: random_meta_
    MPS: random_mps_
  autogen: random.to, random.to_out

- func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  tags: nondeterministic_seeded
  variants: method
  dispatch:
    CPU, CUDA: random_
    MPS: random_mps_
    Meta: random_meta_
  autogen: random, random.out

- func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  tags: nondeterministic_seeded
  variants: method
  dispatch:
    CPU, CUDA: uniform_
    MPS: uniform_mps_
    Meta: uniform_meta_
  autogen: uniform, uniform.out

- func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  tags: nondeterministic_seeded
  dispatch:
    CPU, CUDA: cauchy_
  autogen: cauchy, cauchy.out

- func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  tags: nondeterministic_seeded
  variants: method
  dispatch:
    CPU, CUDA: log_normal_
  autogen: log_normal, log_normal.out

- func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  tags: nondeterministic_seeded
  variants: method
  dispatch:
    CPU, CUDA: exponential_
    MPS: exponential_mps_
  autogen: exponential, exponential.out

- func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  tags: nondeterministic_seeded
  variants: method
  dispatch:
    CPU, CUDA: geometric_

  # wrappers for TH functions
  autogen: geometric, geometric.out

- func: diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)

- func: diag(Tensor self, int diagonal=0) -> Tensor
  variants: method, function

- func: cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)

- func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor
  variants: method, function

- func: triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:
    CPU: triu_cpu
    CUDA: triu_cuda
    MPS: triu_mps_out

- func: triu(Tensor self, int diagonal=0) -> Tensor
  structured_delegate: triu.out
  variants: method, function
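
# Illustrative note (not part of the schema): triu keeps the part of the
# matrix on and above the selected diagonal and zeroes the rest; tril (just
# below) is its lower-triangular counterpart:
#
#   x = torch.ones(3, 3)
#   torch.triu(x)               # zeroes entries below the main diagonal
#   torch.tril(x, diagonal=-1)  # keeps only entries strictly below the main diagonal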

- func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:
    CPU: tril_cpu
    CUDA: tril_cuda
    MPS: tril_mps_out

- func: tril(Tensor self, int diagonal=0) -> Tensor
  structured_delegate: tril.out
  variants: method, function

- func: tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CPU: tril_indices_cpu
    CUDA: tril_indices_cuda
  autogen: tril_indices.out

- func: triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  dispatch:
    CPU: triu_indices_cpu
    CUDA: triu_indices_cuda
  autogen: triu_indices.out

- func: trace(Tensor self) -> Tensor
  variants: method, function
  dispatch:
    CPU: trace_cpu
    CUDA: trace_cuda
    MPS: trace_mps
  autogen: trace.out

- func: trace_backward(Tensor grad, SymInt[] sizes) -> Tensor
  variants: function
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeImplicitAutograd: trace_backward_symint

- func: ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: ne_Scalar_out
    MPS: ne_scalar_out_mps
    QuantizedCPU: ne_out_quantized_cpu
  tags: pointwise

- func: ne.Scalar(Tensor self, Scalar other) -> Tensor
  structured_delegate: ne.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: ne_quantized_cpu
  tags: [core, pointwise]

- func: ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: ne_Tensor_out
    MPS: ne_tensor_out_mps
    QuantizedCPU: ne_out_quantized_cpu
  tags: pointwise

- func: ne.Tensor(Tensor self, Tensor other) -> Tensor
  structured_delegate: ne.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: ne_quantized_cpu
  tags: [core, pointwise]

- func: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  structured_delegate: ne.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method

- func: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: ne.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method

# not_equal, alias for torch.ne
- func: not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)

- func: not_equal.Scalar(Tensor self, Scalar other) -> Tensor
  variants: method, function

- func: not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

- func: not_equal.Tensor(Tensor self, Tensor other) -> Tensor
  variants: method, function

- func: not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method

- func: not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method
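
# Illustrative note (not part of the schema): torch.not_equal is a pure alias
# of torch.ne, which is also what the `!=` operator lowers to:
#
#   a = torch.tensor([1, 2, 3])
#   b = torch.tensor([1, 0, 3])
#   torch.ne(a, b)         # tensor([False,  True, False])
#   torch.not_equal(a, b)  # same result via the alias
#   a != b                 # operator form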

- func: eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: eq_Scalar_out
    MPS: eq_scalar_out_mps
    QuantizedCPU: eq_out_quantized_cpu
  tags: pointwise

- func: eq.Scalar(Tensor self, Scalar other) -> Tensor
  structured_delegate: eq.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: eq_quantized_cpu
    NestedTensorCPU, NestedTensorCUDA: eq_scalar_nested
  tags: [core, pointwise]

- func: eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: eq_Tensor_out
    MPS: eq_tensor_out_mps
    QuantizedCPU: eq_out_quantized_cpu
  tags: pointwise

- func: eq.Tensor(Tensor self, Tensor other) -> Tensor
  structured_delegate: eq.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: eq_quantized_cpu
  tags: [core, pointwise]

- func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: ge_Scalar_out
    MPS: ge_scalar_out_mps
    QuantizedCPU: ge_out_quantized_cpu
  tags: pointwise

- func: ge.Scalar(Tensor self, Scalar other) -> Tensor
  structured_delegate: ge.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: ge_quantized_cpu
    NestedTensorCPU, NestedTensorCUDA: ge_scalar_nested
  tags: [core, pointwise]

- func: ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: ge_Tensor_out
    MPS: ge_tensor_out_mps
    QuantizedCPU: ge_out_quantized_cpu
  tags: pointwise

- func: ge.Tensor(Tensor self, Tensor other) -> Tensor
  structured_delegate: ge.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: ge_quantized_cpu
  tags: [core, pointwise]

- func: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  structured_delegate: ge.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method

- func: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: ge.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method

# greater_equal, alias for torch.ge
- func: greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)

- func: greater_equal.Scalar(Tensor self, Scalar other) -> Tensor
  variants: method, function

- func: greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

- func: greater_equal.Tensor(Tensor self, Tensor other) -> Tensor
  variants: method, function

- func: greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method

- func: greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method

- func: le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: le_Scalar_out
    MPS: le_scalar_out_mps
    QuantizedCPU: le_out_quantized_cpu
  tags: pointwise

- func: le.Scalar(Tensor self, Scalar other) -> Tensor
  structured_delegate: le.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: le_quantized_cpu
  tags: [core, pointwise]

- func: le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: le_Tensor_out
    MPS: le_tensor_out_mps
    QuantizedCPU: le_out_quantized_cpu
  tags: pointwise

- func: le.Tensor(Tensor self, Tensor other) -> Tensor
  structured_delegate: le.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: le_quantized_cpu
  tags: [core, pointwise]

- func: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  structured_delegate: le.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method

- func: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: le.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method

# less_equal, alias for torch.le
- func: less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)

- func: less_equal.Scalar(Tensor self, Scalar other) -> Tensor
  variants: method, function

- func: less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

- func: less_equal.Tensor(Tensor self, Tensor other) -> Tensor
  variants: method, function

- func: less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method

- func: less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method

- func: gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: gt_Scalar_out
    MPS: gt_scalar_out_mps
    QuantizedCPU: gt_out_quantized_cpu
  tags: pointwise

- func: gt.Scalar(Tensor self, Scalar other) -> Tensor
  structured_delegate: gt.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: gt_quantized_cpu
    NestedTensorCPU, NestedTensorCUDA: gt_scalar_nested
  tags: [core, pointwise]

- func: gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: gt_Tensor_out
    MPS: gt_tensor_out_mps
    QuantizedCPU: gt_out_quantized_cpu
  tags: pointwise

- func: gt.Tensor(Tensor self, Tensor other) -> Tensor
  structured_delegate: gt.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: gt_quantized_cpu
  tags: [core, pointwise]

- func: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  structured_delegate: gt.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method

- func: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: gt.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method

# greater, alias for torch.gt
- func: greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)

- func: greater.Scalar(Tensor self, Scalar other) -> Tensor
  variants: method, function

- func: greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

- func: greater.Tensor(Tensor self, Tensor other) -> Tensor
  variants: method, function

- func: greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method

- func: greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method

- func: lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: lt_Scalar_out
    MPS: lt_scalar_out_mps
    QuantizedCPU: lt_out_quantized_cpu
  tags: pointwise

- func: lt.Scalar(Tensor self, Scalar other) -> Tensor
  structured_delegate: lt.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: lt_quantized_cpu
  tags: [core, pointwise]

- func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: lt_Tensor_out
    MPS: lt_tensor_out_mps
    QuantizedCPU: lt_out_quantized_cpu
  tags: pointwise

- func: lt.Tensor(Tensor self, Tensor other) -> Tensor
  structured_delegate: lt.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    QuantizedCPU: lt_quantized_cpu
  tags: [core, pointwise]

- func: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  structured_delegate: lt.Scalar_out
  device_check: NoCheck # TensorIterator
  variants: method

- func: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: lt.Tensor_out
  device_check: NoCheck # TensorIterator
  variants: method

# less, alias for torch.lt
- func: less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)

- func: less.Scalar(Tensor self, Scalar other) -> Tensor
  variants: method, function

- func: less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

- func: less.Tensor(Tensor self, Tensor other) -> Tensor
  variants: method, function

- func: less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method

- func: less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method

- func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: take_out

- func: take(Tensor self, Tensor index) -> Tensor
  variants: method, function
  dispatch:
    CPU, CUDA: take
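
# Illustrative note (not part of the schema): take indexes the input as if it
# were flattened to 1-D, i.e. roughly self.reshape(-1)[index]:
#
#   x = torch.tensor([[4, 3, 5], [6, 7, 8]])
#   torch.take(x, torch.tensor([0, 2, 5]))   # tensor([4, 5, 8])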

- func: take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)

- func: take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor
  variants: method, function

- func: index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, QuantizedCPU: index_select_out_cpu_
    CUDA, QuantizedCUDA: index_select_out_cuda
    MPS: index_select_out_mps

- func: index_select(Tensor self, int dim, Tensor index) -> Tensor
  variants: method, function
  dispatch:
    CPU: index_select_cpu_
    QuantizedCPU: index_select_quantized_cpu_
    CUDA: index_select_cuda
    QuantizedCUDA: index_select_quantized_cuda
    SparseCPU: index_select_sparse_cpu
    SparseCUDA: index_select_sparse_cuda
    MPS: index_select_mps
  tags: core

- func: index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)

- func: index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor
  variants: method, function

- func: index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor
  variants: function
  device_check: NoCheck
  device_guard: False
  dispatch:
    CompositeImplicitAutograd: index_select_backward_symint

- func: masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: masked_select_out_cpu
    CUDA: masked_select_out_cuda
    MPS: masked_select_out_mps
  tags: dynamic_output_shape

- func: masked_select(Tensor self, Tensor mask) -> Tensor
  variants: method, function
  dispatch:
    CPU: masked_select_cpu
    CUDA: masked_select_cuda
    MPS: masked_select_mps
  tags: dynamic_output_shape

- func: masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor
  variants: function
  device_check: NoCheck
  device_guard: False

- func: nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: nonzero_out_cpu
    CUDA: nonzero_out_cuda
    MPS: nonzero_out_mps
  tags: dynamic_output_shape

- func: nonzero(Tensor self) -> Tensor
  variants: method, function
  dispatch:
    CPU: nonzero_cpu
    CUDA: nonzero_cuda
    MPS: nonzero_mps
  tags: [dynamic_output_shape, core]

- func: nonzero_static.out(Tensor self, *, int size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: nonzero_static_out_cpu

- func: nonzero_static(Tensor self, *, int size, int fill_value=-1) -> Tensor
  variants: method, function
  dispatch:
    CPU: nonzero_static_cpu

- func: nonzero_numpy(Tensor self) -> Tensor[]
  variants: method, function

- func: argwhere(Tensor self) -> Tensor
  variants: method, function
  tags: dynamic_output_shape

- func: gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
  structured: True
  dispatch:
    CPU, CUDA: gather_out
    MPS: gather_out_mps

- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
  variants: method, function
  structured_delegate: gather.out
  tags: core

- func: gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor
  variants: function
  device_check: NoCheck
  device_guard: False
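
# Illustrative note (not part of the schema): for a 2-D input and dim=1,
# gather behaves roughly as out[i][j] = self[i][index[i][j]]:
#
#   x = torch.tensor([[1, 2], [3, 4]])
#   idx = torch.tensor([[0, 0], [1, 0]])
#   torch.gather(x, 1, idx)   # tensor([[1, 1], [4, 3]])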

- func: gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)

- func: gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor
  variants: method, function

- func: _gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor

- func: addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: addcmul_out
    MPS: addcmul_out_mps
  tags: pointwise

- func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
  structured_delegate: addcmul.out
  device_check: NoCheck # TensorIterator
  variants: method, function
  tags: pointwise

- func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
  structured_delegate: addcmul.out
  device_check: NoCheck # TensorIterator
  variants: method
  tags: pointwise

- func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: addcdiv_out
    MPS: addcdiv_out_mps
  tags: pointwise

- func: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
  structured_delegate: addcdiv.out
  device_check: NoCheck # TensorIterator
  variants: method, function
  tags: pointwise

- func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
  structured_delegate: addcdiv.out
  device_check: NoCheck # TensorIterator
  variants: method
  tags: pointwise

- func: cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor
  python_module: nn
  dispatch:
    CompositeImplicitAutograd: cross_entropy_loss_symint

- func: triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient)
  structured: True
  dispatch:
    CPU, CUDA: triangular_solve_out
    MPS: triangular_solve_mps_out
    SparseCsrCPU: triangular_solve_out_sparse_csr_cpu
    SparseCsrCUDA: triangular_solve_out_sparse_csr_cuda

- func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient)
  structured_delegate: triangular_solve.X
  variants: method, function

- func: _linalg_check_errors(Tensor info, str api_name, *, bool is_matrix) -> ()
  dispatch:
    CompositeExplicitAutograd: _linalg_check_errors

- func: linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  dispatch:
    CPU, CUDA: linalg_solve_triangular_out
    MPS: linalg_solve_triangular_mps_out

- func: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
  python_module: linalg
  variants: function
  dispatch:
    CPU, CUDA: linalg_solve_triangular
    MPS: linalg_solve_triangular_mps
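
# Illustrative note (not part of the schema): linalg_solve_triangular solves
# A X = B (or X A = B with left=False) with a triangular A, without
# factorizing it; `upper` selects which triangle of A is referenced:
#
#   A = torch.randn(3, 3).triu()
#   B = torch.randn(3, 2)
#   X = torch.linalg.solve_triangular(A, B, upper=True)
#   torch.allclose(A @ X, B)   # True, up to numerical tolerance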

- func: linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor
  python_module: linalg
  dispatch:
    CompositeImplicitAutograd: linalg_vander_symint

- func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)

- func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
  variants: method, function

# swapaxes, alias for transpose
- func: swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a)
  variants: function, method
  device_check: NoCheck
  device_guard: False

- func: swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!)
  variants: method
  device_check: NoCheck
  device_guard: False
  tags: inplace_view

# swapdims, alias for transpose
- func: swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
  variants: function, method
  device_check: NoCheck
  device_guard: False

- func: swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
  variants: method
  device_check: NoCheck
  device_guard: False
  tags: inplace_view

- func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: cholesky_out

- func: cholesky(Tensor self, bool upper=False) -> Tensor
  variants: method, function
  dispatch:
    CPU, CUDA: cholesky

- func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: cholesky_solve_out

- func: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
  variants: method, function
  dispatch:
    CompositeExplicitAutograd: cholesky_solve

- func: _cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor
  variants: function
  dispatch:
    CPU: _cholesky_solve_helper_cpu
    CUDA: _cholesky_solve_helper_cuda
  autogen: _cholesky_solve_helper.out

- func: cholesky_inverse(Tensor self, bool upper=False) -> Tensor
  variants: method, function
  dispatch:
    CPU, CUDA: cholesky_inverse

- func: cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: cholesky_inverse_out

- func: qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)

- func: qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R)
  variants: method, function

- func: geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau)
  dispatch:
    CPU, CUDA: geqrf_out

- func: geqrf(Tensor self) -> (Tensor a, Tensor tau)
  variants: method, function
  dispatch:
    CPU, CUDA: geqrf

# orgqr, alias for linalg_householder_product
- func: orgqr(Tensor self, Tensor input2) -> Tensor
  variants: method, function

- func: orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!)
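
# Illustrative note (not part of the schema): geqrf returns the compact
# Householder representation (a, tau) of a QR decomposition; orgqr -- an
# alias for torch.linalg.householder_product -- expands it into the explicit
# Q factor:
#
#   A = torch.randn(4, 3)
#   a, tau = torch.geqrf(A)
#   Q = torch.linalg.householder_product(a, tau)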

- func: ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: ormqr_out

- func: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor
  variants: method, function
  dispatch:
    CPU, CUDA: ormqr

- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
  variants: function

- func: lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)

- func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
  variants: method, function

# lu_unpack
- func: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
  structured_delegate: lu_unpack.out
  variants: function

- func: lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)
  variants: function
  structured: True
  dispatch:
    CPU, CUDA: lu_unpack_out

# TODO: remove dispatch section when porting TH CUDA to ATen
- func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
  tags: nondeterministic_seeded
  dispatch:
    CPU, CUDA: multinomial_out
    MPS: multinomial_out_mps

- func: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor
  variants: method, function
  dispatch:
    CPU, CUDA: multinomial
    MPS: multinomial_mps
  tags: nondeterministic_seeded

- func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: lgamma_out
    MPS: lgamma_out_mps
  tags: pointwise

- func: lgamma_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured_delegate: lgamma.out
  variants: method
  tags: pointwise

- func: lgamma(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: lgamma.out
  variants: method, function
  tags: pointwise

- func: digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: digamma_out
    MPS: digamma_out_mps
  tags: pointwise

- func: digamma(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: digamma.out
  variants: method, function
  tags: pointwise

- func: polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: polygamma_out
    MPS: polygamma_out_mps
  tags: pointwise

- func: polygamma(int n, Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: polygamma.out
  variants: method, function
  tags: pointwise

- func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  dispatch:
    CompositeExplicitAutograd: polygamma_
  tags: pointwise

- func: erfinv(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: erfinv.out
  variants: method, function
  dispatch:
    SparseCPU, SparseCUDA: erfinv_sparse
    SparseCsrCPU, SparseCsrCUDA: erfinv_sparse_csr
  tags: pointwise

- func: erfinv_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured_delegate: erfinv.out
  variants: method
  dispatch:
    SparseCPU, SparseCUDA: erfinv_sparse_
    SparseCsrCPU, SparseCsrCUDA: erfinv_sparse_csr_
  tags: pointwise

- func: erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: erfinv_out
    MPS: erfinv_out_mps
    SparseCPU, SparseCUDA: erfinv_sparse_out
    SparseCsrCPU, SparseCsrCUDA: erfinv_sparse_csr_out
  tags: pointwise

- func: i0(Tensor self) -> Tensor
  structured_delegate: i0.out
  variants: function, method
  tags: pointwise

- func: i0_(Tensor(a!) self) -> Tensor(a!)
  structured_delegate: i0.out
  variants: function, method
  tags: pointwise

- func: i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: i0_out
  tags: pointwise

- func: sign(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: sign.out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: sign_sparse
    SparseCsrCPU, SparseCsrCUDA: sign_sparse_csr
  tags: [core, pointwise]

- func: sign_(Tensor(a!) self) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured_delegate: sign.out
  variants: method
  dispatch:
    SparseCPU, SparseCUDA: sign_sparse_
    SparseCsrCPU, SparseCsrCUDA: sign_sparse_csr_
  tags: pointwise

- func: sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: sign_out
    MPS: sign_out_mps
    SparseCPU, SparseCUDA: sign_sparse_out
    SparseCsrCPU, SparseCsrCUDA: sign_sparse_csr_out
  tags: pointwise

- func: signbit(Tensor self) -> Tensor
  variants: function, method
  structured_delegate: signbit.out
  dispatch:
    SparseCPU, SparseCUDA: signbit_sparse
    SparseCsrCPU, SparseCsrCUDA: signbit_sparse_csr
  tags: pointwise

- func: signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU: signbit_out
    CUDA: signbit_out
    MPS: signbit_out_mps
    SparseCPU, SparseCUDA: signbit_sparse_out
    SparseCsrCPU, SparseCsrCUDA: signbit_sparse_csr_out
  tags: pointwise

- func: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    CompositeExplicitAutograd: dist
  autogen: dist.out
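
# Illustrative note (not part of the schema): dist computes the p-norm of the
# elementwise difference, so torch.dist(a, b, p) matches torch.norm(a - b, p):
#
#   a, b = torch.randn(5), torch.randn(5)
#   torch.dist(a, b, 2)    # Euclidean distance between a and b
#   torch.norm(a - b, 2)   # same value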

- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: atan2_out
    MPS: atan2_out_mps
  tags: [core, pointwise]

- func: atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured_delegate: atan2.out
  variants: method
  tags: pointwise

- func: atan2(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: atan2.out
  variants: method, function
  tags: [core, pointwise]
# arctan2, alias of atan2

- func: arctan2(Tensor self, Tensor other) -> Tensor
  variants: method, function

- func: arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator

- func: arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  variants: method

- func: lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: lerp_Scalar
    MPS: lerp_Scalar_mps
  tags: pointwise

- func: lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: lerp_Tensor
    MPS: lerp_Tensor_mps
  tags: pointwise

- func: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function
  structured_delegate: lerp.Scalar_out
  tags: pointwise

- func: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function
  structured_delegate: lerp.Tensor_out
  tags: pointwise

- func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, MPS: histogram_histc_out
    CUDA: _histc_out_cuda

- func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
  variants: method, function
  dispatch:
    CPU, MPS: histogram_histc
    CUDA: _histc_cuda

- func: histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
  dispatch:
    CPU, MPS: histogram_out

- func: histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
  variants: method, function
  dispatch:
    CPU, MPS: histogram

- func: histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
  dispatch:
    CPU, MPS: histogram_out

- func: histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
  variants: method, function
  dispatch:
    CPU, MPS: histogram
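
# Illustrative note (not part of the schema): the histogram overloads either
# take an explicit bin-edge tensor or a bin count with an optional range, and
# both return a (hist, bin_edges) pair:
#
#   x = torch.rand(100)
#   hist, edges = torch.histogram(x, bins=4, range=(0., 1.))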

- func: _histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[]
  dispatch:
    CPU, MPS: histogramdd_bin_edges
  autogen: _histogramdd_bin_edges.out

- func: _histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor
  dispatch:
    CPU, MPS: _histogramdd
  autogen: _histogramdd_from_bin_cts.out

- func: _histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor
  dispatch:
    CPU, MPS: _histogramdd
  autogen: _histogramdd_from_bin_tensors.out

- func: histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)

- func: histogramdd.int_bins(Tensor self, int bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)

- func: histogramdd.TensorList_bins(Tensor self, Tensor[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)

- func: fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  dispatch:
    CompositeExplicitAutograd: fmod_out
  tags: pointwise

- func: fmod.Scalar(Tensor self, Scalar other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    CompositeExplicitAutograd: fmod
  tags: [core, pointwise]

- func: fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  dispatch:
    CompositeExplicitAutograd: fmod_
  tags: pointwise

- func: fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: fmod_out
    MPS: fmod_mps_out
  tags: pointwise

- func: fmod.Tensor(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: fmod.Tensor_out
  variants: method, function
  tags: [core, pointwise]

- func: fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  structured_delegate: fmod.Tensor_out
  tags: pointwise

- func: hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: hypot_out
    MPS: hypot_out_mps
  tags: pointwise

- func: hypot(Tensor self, Tensor other) -> Tensor
  structured_delegate: hypot.out
  variants: method, function
  tags: pointwise

- func: hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: hypot.out
  variants: method
  tags: pointwise
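
# Illustrative note (not part of the schema): hypot computes the elementwise
# hypotenuse sqrt(self**2 + other**2):
#
#   torch.hypot(torch.tensor([3.]), torch.tensor([4.]))   # tensor([5.])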

- func: igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: igamma_out
  tags: pointwise

- func: igamma(Tensor self, Tensor other) -> Tensor
  structured_delegate: igamma.out
  variants: method, function
  tags: pointwise

- func: igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: igamma.out
  variants: method
  tags: pointwise

- func: igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: igammac_out
  tags: pointwise

- func: igammac(Tensor self, Tensor other) -> Tensor
  structured_delegate: igammac.out
  variants: method, function
  tags: pointwise

- func: igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: igammac.out
  variants: method
  tags: pointwise

- func: nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: nextafter_out
  tags: pointwise

- func: nextafter(Tensor self, Tensor other) -> Tensor
  structured_delegate: nextafter.out
  variants: method, function
  tags: pointwise

- func: nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!)
  structured_delegate: nextafter.out
  variants: method
  tags: pointwise

- func: remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: remainder_out
  tags: pointwise

- func: remainder.Scalar(Tensor self, Scalar other) -> Tensor
  variants: method, function
  dispatch:
    CompositeExplicitAutograd: remainder
  tags: [core, pointwise]

- func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  variants: method
  dispatch:
    CompositeExplicitAutograd: remainder_
  tags: pointwise

- func: remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: remainder_out
    MPS: remainder_out_mps
  tags: pointwise

- func: remainder.Tensor(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: remainder.Tensor_out
  variants: method, function
  tags: [core, pointwise]

- func: remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured_delegate: remainder.Tensor_out
  variants: method
  tags: pointwise

- func: remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: function
  dispatch:
    CPU, CUDA, MPS: remainder
  autogen: remainder.Scalar_Tensor_out
  tags: pointwise

- func: min(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    CPU, CUDA: min
    MPS: min_mps
    QuantizedCPU: min_quantized_cpu

- func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: min_unary_out
    QuantizedCPU: min_quantized_unary_out

- func: fmin(Tensor self, Tensor other) -> Tensor
  structured_delegate: fmin.out
  device_check: NoCheck # TensorIterator
  variants: method, function
  tags: pointwise

- func: fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA, MPS: fmin_out
  tags: pointwise
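
# Illustrative note (not part of the schema): fmin (and fmax below) follow the
# C fmin/fmax convention for NaN -- when exactly one operand is NaN, the other
# operand is returned -- whereas minimum/maximum propagate NaN:
#
#   a = torch.tensor([1., float('nan')])
#   b = torch.tensor([2., 3.])
#   torch.fmin(a, b)      # tensor([1., 3.])
#   torch.minimum(a, b)   # tensor([1., nan])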

- func: max(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    CPU, CUDA: max
    MPS: max_mps
    QuantizedCPU: max_quantized_cpu

- func: fmax(Tensor self, Tensor other) -> Tensor
  structured_delegate: fmax.out
  device_check: NoCheck # TensorIterator
  variants: method, function
  tags: pointwise

- func: fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA, MPS: fmax_out
  tags: pointwise

- func: maximum(Tensor self, Tensor other) -> Tensor
  structured_delegate: maximum.out
  device_check: NoCheck # TensorIterator
  variants: method, function
  tags: [core, pointwise]

- func: maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: maximum_out
    MPS: maximum_out_mps
  tags: pointwise

# binary max, alias of maximum
# NOTE: max is not an alias for maximum, since there is also unary max
- func: max.other(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function
  tags: pointwise

- func: max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  tags: pointwise

- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: max_unary_out
    QuantizedCPU: max_quantized_unary_out

- func: minimum(Tensor self, Tensor other) -> Tensor
  structured_delegate: minimum.out
  device_check: NoCheck # TensorIterator
  variants: method, function
  tags: [core, pointwise]

- func: minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: minimum_out
    MPS: minimum_out_mps
  tags: pointwise

# binary min, alias for minimum
# NOTE: min is not an alias for minimum, since there is also unary min
- func: min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  tags: pointwise

- func: min.other(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function
  tags: pointwise

- func: quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
  variants: method, function

- func: quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)

- func: quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
  variants: method, function

- func: quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)

- func: nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
  variants: method, function

- func: nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)

- func: nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
  variants: method, function

- func: nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)

- func: sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
  device_check: NoCheck # TensorIterator
  dispatch:
    CompositeExplicitAutograd: sort_out

- func: sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
  structured: True
  dispatch:
    CPU, CUDA: sort_stable_out
    MPS: sort_stable_out_mps

- func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    CompositeExplicitAutograd: sort
  tags: core

- func: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
  structured_delegate: sort.values_stable
  variants: method, function
  dispatch:
    QuantizedCPU: sort_quantized_cpu_stable

- func: sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)

- func: sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)

- func: sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
  variants: method, function

- func: sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
  variants: method, function

- func: msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

- func: msort(Tensor self) -> Tensor
  variants: method, function

- func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function

- func: argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function
  dispatch:
    CPU, CUDA, MPS: argsort_stable
  autogen: argsort.stable_out

- func: argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor
  variants: method, function
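
# Illustrative note (not part of the schema), assuming the standard torch
# Python API: the *.stable overloads guarantee that equivalent elements keep
# their relative order:
#
#   x = torch.tensor([2, 1, 2, 1])
#   values, indices = torch.sort(x, stable=True)
#   order = torch.argsort(x, stable=True)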

- func: topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
  structured: True
  dispatch:
    CPU: topk_out_cpu
    CUDA: topk_out_cuda
    MPS: topk_out_mps

- func: topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
  variants: method, function
  structured_delegate: topk.values
  dispatch:
    QuantizedCPU: topk_quantized_cpu
  tags: core

- func: all(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: all.all_out
  variants: method, function

- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  structured: True
  dispatch:
    CPU, CUDA: all_all_out
    MPS: all_all_out_mps

- func: any(Tensor self) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: any.all_out
  variants: method, function
  dispatch:
    SparseCPU, SparseCUDA: any_sparse
  tags: core

- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  structured: True
  dispatch:
    CPU, CUDA: any_all_out
    MPS: any_all_out_mps

- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  dispatch:
    CPU, CUDA: renorm_out
    MPS: renorm_out_mps

- func: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
  device_check: NoCheck # TensorIterator
  variants: method, function
  structured_delegate: renorm.out

- func: renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: method
  structured_delegate: renorm.out

- func: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)
  variants: method
  device_check: NoCheck
  device_guard: False
  dispatch:
    CPU, CUDA, Meta, MPS: unfold
    QuantizedCPU, QuantizedCUDA: unfold

- func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
  variants: function
  dispatch:
    CPU, CUDA: unfold_backward
  autogen: unfold_backward.out

- func: equal(Tensor self, Tensor other) -> bool
  tags: [data_dependent_output, pointwise]
  variants: method, function
  dispatch:
    CPU: cpu_equal
    CUDA: cuda_equal
    MPS: mps_equal
    QuantizedCPU: equal_quantized_cpu

- func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: pow_Tensor_Tensor_out
    MPS: pow_tensor_tensor_out_mps
  tags: pointwise

- func: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: pow.Tensor_Tensor_out
  variants: method, function
  tags: [core, pointwise]

- func: pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  dispatch:
    CPU, CUDA: pow_Scalar_out
    MPS: pow_Scalar_out_mps
  tags: pointwise

- func: pow.Scalar(Scalar self, Tensor exponent) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: pow.Scalar_out
  tags: [core, pointwise]

- func: pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: pow_Tensor_Scalar_out
    SparseCPU, SparseCUDA: pow_out_sparse_scalar
    MPS: pow_tensor_scalar_out_mps
  tags: pointwise

- func: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: pow.Tensor_Scalar_out
  variants: function, method
  dispatch:
    SparseCPU, SparseCUDA: pow_sparse_scalar
  tags: [core, pointwise]

- func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured_delegate: pow.Tensor_Scalar_out
  variants: method
  tags: pointwise

- func: pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured_delegate: pow.Tensor_Tensor_out
  variants: method
  tags: pointwise

- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
  tags: pointwise

- func: float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
  variants: function, method
  tags: pointwise

- func: float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
  tags: pointwise

- func: float_power.Scalar(Scalar self, Tensor exponent) -> Tensor
  tags: pointwise

- func: float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
  tags: pointwise

- func: float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
  variants: function, method
  tags: pointwise

- func: float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
  variants: method
  tags: pointwise

- func: float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
  variants: method
  tags: pointwise

- func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  tags: nondeterministic_seeded
  variants: method
  dispatch:
    CPU, CUDA: normal_
    MPS: normal_mps_
    Meta: normal_meta_
    SparseCsrCPU, SparseCsrCUDA: normal_sparse_csr_
    NestedTensorCPU, NestedTensorCUDA: normal_nested_
  autogen: normal.out

# Only used by the functionalization pass.
# Normally, the codegen would be able to generate a normal() NativeFunction,
# but we can't due to overload ambiguity with normal.Tensor_float.
- func: normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor
  device_check: NoCheck # TensorIterator
  tags: nondeterministic_seeded
  dispatch:
    CompositeExplicitAutograd: normal_functional
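
# Illustrative note (not part of the schema), assuming the standard torch
# Python API: the normal.* overloads below cover the tensor/float
# combinations of mean and std -- the overload set behind the ambiguity
# mentioned above:
#
#   torch.normal(torch.zeros(3), 1.0)             # Tensor_float
#   torch.normal(0.0, torch.ones(3))              # float_Tensor
#   torch.normal(torch.zeros(3), torch.ones(3))   # Tensor_Tensor
#   torch.normal(0.0, 1.0, (3,))                  # float_float, requires a size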
10237 tags: nondeterministic_seeded 10238 dispatch: 10239 CPU, CUDA: normal_out 10240 MPS: normal_mps_out 10241 Meta: normal_out_meta 10242 10243- func: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor 10244 dispatch: 10245 CPU, CUDA: normal 10246 MPS: normal_mps 10247 Meta: normal_meta 10248 tags: nondeterministic_seeded 10249 10250- func: normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 10251 dispatch: 10252 CPU, CUDA: normal_out 10253 Meta: normal_out_meta 10254 MPS: normal_mps_out 10255 tags: nondeterministic_seeded 10256 10257- func: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor 10258 dispatch: 10259 CPU, CUDA: normal 10260 MPS: normal_mps 10261 Meta: normal_meta 10262 tags: nondeterministic_seeded 10263 10264- func: normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 10265 dispatch: 10266 CPU, CUDA: normal_out 10267 Meta: normal_out_meta 10268 MPS: normal_mps_out 10269 tags: nondeterministic_seeded 10270 10271- func: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor 10272 dispatch: 10273 CPU, CUDA: normal 10274 MPS: normal_mps 10275 Meta: normal_meta 10276 tags: nondeterministic_seeded 10277 10278- func: normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 10279 dispatch: 10280 CompositeExplicitAutograd: normal 10281 tags: nondeterministic_seeded 10282 10283- func: normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) 10284 dispatch: 10285 CompositeExplicitAutograd: normal_out 10286 tags: nondeterministic_seeded 10287 10288- func: alias(Tensor(a) self) -> Tensor(a) 10289 variants: method, function 10290 dispatch: 10291 CompositeExplicitAutograd: alias 10292 NestedTensorCPU, NestedTensorCUDA: alias_nested 10293 tags: core 10294 10295- func: _amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> () 10296 variants: function 10297 dispatch: 10298 CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_ 10299 CPU: _amp_foreach_non_finite_check_and_unscale_cpu_ 10300 autogen: _amp_foreach_non_finite_check_and_unscale, _amp_foreach_non_finite_check_and_unscale.out 10301 10302- func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!) 10303 variants: function 10304 dispatch: 10305 CUDA: _amp_update_scale_cuda_ 10306 CPU: _amp_update_scale_cpu_ 10307 autogen: _amp_update_scale, _amp_update_scale.out 10308 10309 #- func: _cat(Tensor[] tensors, int dim=0) -> Tensor 10310 #dispatch: 10311 #CPU: _cat_cpu 10312 #CUDA: cat_cuda 10313 #MPS: cat_mps 10314 #QuantizedCPU: cat_quantized_cpu 10315 10316 #- func: _cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) 
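# Illustrative note, not part of the operator schema: the `normal` overloads above map
# onto the argument combinations accepted by `torch.normal` (tensor or float mean/std,
# plus an explicit-size form). A minimal sketch, assuming standard torch semantics:
#
# import torch
# torch.normal(mean=torch.zeros(3), std=1.0)        # Tensor_float
# torch.normal(mean=0.0, std=torch.ones(3))         # float_Tensor
# torch.normal(torch.zeros(3), torch.ones(3))       # Tensor_Tensor
# torch.normal(0.0, 1.0, size=(3,))                 # float_float (explicit size)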
10317 #dispatch: 10318 #CPU: _cat_out_cpu 10319 #CUDA: cat_out_cuda 10320 #QuantizedCPU: cat_out_quantized_cpu 10321 10322- func: _foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] 10323 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10324 variants: function 10325 dispatch: 10326 CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow 10327 CUDA: foreach_tensor_add_scalar_kernel_cuda 10328 10329- func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () 10330 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10331 variants: function 10332 dispatch: 10333 CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_ 10334 CUDA: foreach_tensor_add_scalar_kernel_cuda_ 10335 autogen: _foreach_add.Scalar_out 10336 10337- func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] 10338 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10339 variants: function 10340 dispatch: 10341 CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow 10342 CUDA: foreach_tensor_add_list_kernel_cuda 10343 10344- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () 10345 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10346 variants: function 10347 dispatch: 10348 CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_ 10349 CUDA: foreach_tensor_add_list_kernel_cuda_ 10350 autogen: _foreach_add.List_out 10351 10352- func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] 10353 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10354 variants: function 10355 dispatch: 10356 CompositeExplicitAutograd: foreach_tensor_add_scalarlist_kernel_slow 10357 CUDA: foreach_tensor_add_scalarlist_kernel_cuda 10358 10359- func: _foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () 10360 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10361 variants: function 10362 dispatch: 10363 CompositeExplicitAutograd: foreach_tensor_add_scalarlist_kernel_slow_ 10364 CUDA: foreach_tensor_add_scalarlist_kernel_cuda_ 10365 autogen: _foreach_add.ScalarList_out 10366 10367- func: _foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[] 10368 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10369 variants: function 10370 dispatch: 10371 CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow 10372 CUDA: foreach_tensor_add_tensor_kernel_cuda 10373 10374- func: _foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> () 10375 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10376 variants: function 10377 dispatch: 10378 CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_ 10379 CUDA: foreach_tensor_add_tensor_kernel_cuda_ 10380 autogen: _foreach_add.Tensor_out 10381 10382- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] 10383 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10384 variants: function 10385 dispatch: 10386 CompositeExplicitAutograd: foreach_tensor_sub_scalar_kernel_slow 10387 CUDA: foreach_tensor_sub_scalar_kernel_cuda 10388 10389- func: 
_foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () 10390 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10391 variants: function 10392 dispatch: 10393 CompositeExplicitAutograd: foreach_tensor_sub_scalar_kernel_slow_ 10394 CUDA: foreach_tensor_sub_scalar_kernel_cuda_ 10395 autogen: _foreach_sub.Scalar_out 10396 10397- func: _foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] 10398 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10399 variants: function 10400 dispatch: 10401 CompositeExplicitAutograd: foreach_tensor_sub_list_kernel_slow 10402 CUDA: foreach_tensor_sub_list_kernel_cuda 10403 10404- func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () 10405 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10406 variants: function 10407 dispatch: 10408 CompositeExplicitAutograd: foreach_tensor_sub_list_kernel_slow_ 10409 CUDA: foreach_tensor_sub_list_kernel_cuda_ 10410 autogen: _foreach_sub.List_out 10411 10412- func: _foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] 10413 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10414 variants: function 10415 dispatch: 10416 CompositeExplicitAutograd: foreach_tensor_sub_scalarlist_kernel_slow 10417 CUDA: foreach_tensor_sub_scalarlist_kernel_cuda 10418 10419- func: _foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () 10420 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10421 variants: function 10422 dispatch: 10423 CompositeExplicitAutograd: foreach_tensor_sub_scalarlist_kernel_slow_ 10424 CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_ 10425 autogen: _foreach_sub.ScalarList_out 10426 10427- func: _foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] 10428 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10429 variants: function 10430 dispatch: 10431 CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow 10432 CUDA: foreach_tensor_mul_scalar_kernel_cuda 10433 10434- func: _foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () 10435 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10436 variants: function 10437 dispatch: 10438 CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_ 10439 CUDA: foreach_tensor_mul_scalar_kernel_cuda_ 10440 autogen: _foreach_mul.Scalar_out 10441 10442- func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[] 10443 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10444 variants: function 10445 dispatch: 10446 CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow 10447 CUDA: foreach_tensor_mul_list_kernel_cuda 10448 10449- func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () 10450 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10451 variants: function 10452 dispatch: 10453 CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_ 10454 CUDA: foreach_tensor_mul_list_kernel_cuda_ 10455 autogen: _foreach_mul.List_out 10456 10457- func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] 10458 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on 
different devices 10459 variants: function 10460 dispatch: 10461 CompositeExplicitAutograd: foreach_tensor_mul_scalarlist_kernel_slow 10462 CUDA: foreach_tensor_mul_scalarlist_kernel_cuda 10463 10464- func: _foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () 10465 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10466 variants: function 10467 dispatch: 10468 CompositeExplicitAutograd: foreach_tensor_mul_scalarlist_kernel_slow_ 10469 CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_ 10470 autogen: _foreach_mul.ScalarList_out 10471 10472- func: _foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[] 10473 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10474 variants: function 10475 dispatch: 10476 CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow 10477 CUDA: foreach_tensor_mul_tensor_kernel_cuda 10478 10479- func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> () 10480 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10481 variants: function 10482 dispatch: 10483 CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_ 10484 CUDA: foreach_tensor_mul_tensor_kernel_cuda_ 10485 autogen: _foreach_mul.Tensor_out 10486 10487- func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] 10488 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10489 variants: function 10490 dispatch: 10491 CompositeExplicitAutograd: foreach_tensor_div_scalar_kernel_slow 10492 CUDA: foreach_tensor_div_scalar_kernel_cuda 10493 10494- func: _foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () 10495 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10496 variants: function 10497 dispatch: 10498 CompositeExplicitAutograd: foreach_tensor_div_scalar_kernel_slow_ 10499 CUDA: foreach_tensor_div_scalar_kernel_cuda_ 10500 autogen: _foreach_div.Scalar_out 10501 10502- func: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[] 10503 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10504 variants: function 10505 dispatch: 10506 CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow 10507 CUDA: foreach_tensor_div_list_kernel_cuda 10508 10509- func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () 10510 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10511 variants: function 10512 dispatch: 10513 CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow_ 10514 CUDA: foreach_tensor_div_list_kernel_cuda_ 10515 autogen: _foreach_div.List_out 10516 10517- func: _foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] 10518 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10519 variants: function 10520 dispatch: 10521 CompositeExplicitAutograd: foreach_tensor_div_scalarlist_kernel_slow 10522 CUDA: foreach_tensor_div_scalarlist_kernel_cuda 10523 10524- func: _foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () 10525 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10526 variants: function 10527 dispatch: 10528 CompositeExplicitAutograd: foreach_tensor_div_scalarlist_kernel_slow_ 10529 CUDA: foreach_tensor_div_scalarlist_kernel_cuda_ 10530 autogen: 
_foreach_div.ScalarList_out 10531 10532- func: _foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[] 10533 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10534 variants: function 10535 dispatch: 10536 CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow 10537 CUDA: foreach_tensor_div_tensor_kernel_cuda 10538 10539- func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () 10540 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10541 variants: function 10542 dispatch: 10543 CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow_ 10544 CUDA: foreach_tensor_div_tensor_kernel_cuda_ 10545 autogen: _foreach_div.Tensor_out 10546 10547- func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] 10548 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10549 variants: function 10550 dispatch: 10551 CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow 10552 CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda 10553 10554- func: _foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () 10555 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10556 variants: function 10557 dispatch: 10558 CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_ 10559 CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ 10560 autogen: _foreach_clamp_max.Scalar_out 10561 10562- func: _foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[] 10563 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10564 variants: function 10565 dispatch: 10566 CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow 10567 CUDA: foreach_tensor_clamp_max_list_kernel_cuda 10568 10569- func: _foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> () 10570 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10571 variants: function 10572 dispatch: 10573 CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_ 10574 CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ 10575 autogen: _foreach_clamp_max.List_out 10576 10577- func: _foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] 10578 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10579 variants: function 10580 dispatch: 10581 CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow 10582 CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda 10583 10584- func: _foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () 10585 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10586 variants: function 10587 dispatch: 10588 CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_ 10589 CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ 10590 autogen: _foreach_clamp_max.ScalarList_out 10591 10592- func: _foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] 10593 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10594 variants: function 10595 dispatch: 10596 CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow 10597 CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda 10598 10599- func: _foreach_clamp_min_.Scalar(Tensor(a!)[] 
self, Scalar scalar) -> () 10600 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10601 variants: function 10602 dispatch: 10603 CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow_ 10604 CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ 10605 autogen: _foreach_clamp_min.Scalar_out 10606 10607- func: _foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[] 10608 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10609 variants: function 10610 dispatch: 10611 CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow 10612 CUDA: foreach_tensor_clamp_min_list_kernel_cuda 10613 10614- func: _foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> () 10615 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10616 variants: function 10617 dispatch: 10618 CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow_ 10619 CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ 10620 autogen: _foreach_clamp_min.List_out 10621 10622- func: _foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] 10623 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10624 variants: function 10625 dispatch: 10626 CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow 10627 CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda 10628 10629- func: _foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () 10630 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10631 variants: function 10632 dispatch: 10633 CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow_ 10634 CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ 10635 autogen: _foreach_clamp_min.ScalarList_out 10636 10637# foreach_minimum/maximum dispatches to clamp_max/min 10638- func: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] 10639 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10640 variants: function 10641 dispatch: 10642 CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow 10643 CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda 10644 10645- func: _foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () 10646 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10647 variants: function 10648 dispatch: 10649 CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow_ 10650 CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_ 10651 autogen: _foreach_maximum.Scalar_out 10652 10653# foreach_minimum/maximum dispatches to clamp_max/min 10654- func: _foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[] 10655 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10656 variants: function 10657 dispatch: 10658 CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow 10659 CUDA: foreach_tensor_clamp_min_list_kernel_cuda 10660 10661- func: _foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> () 10662 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10663 variants: function 10664 dispatch: 10665 CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow_ 10666 CUDA: foreach_tensor_clamp_min_list_kernel_cuda_ 10667 
autogen: _foreach_maximum.List_out 10668 10669# foreach_minimum/maximum dispatches to clamp_max/min 10670- func: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] 10671 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10672 variants: function 10673 dispatch: 10674 CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow 10675 CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda 10676 10677- func: _foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () 10678 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10679 variants: function 10680 dispatch: 10681 CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow_ 10682 CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_ 10683 autogen: _foreach_maximum.ScalarList_out 10684 10685- func: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] 10686 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10687 variants: function 10688 dispatch: 10689 CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow 10690 CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda 10691 10692- func: _foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () 10693 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10694 variants: function 10695 dispatch: 10696 CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_ 10697 CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_ 10698 autogen: _foreach_minimum.Scalar_out 10699 10700- func: _foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[] 10701 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10702 variants: function 10703 dispatch: 10704 CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow 10705 CUDA: foreach_tensor_clamp_max_list_kernel_cuda 10706 10707- func: _foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> () 10708 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10709 variants: function 10710 dispatch: 10711 CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_ 10712 CUDA: foreach_tensor_clamp_max_list_kernel_cuda_ 10713 autogen: _foreach_minimum.List_out 10714 10715- func: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] 10716 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10717 variants: function 10718 dispatch: 10719 CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow 10720 CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda 10721 10722- func: _foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () 10723 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10724 variants: function 10725 dispatch: 10726 CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_ 10727 CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_ 10728 autogen: _foreach_minimum.ScalarList_out 10729 10730- func: _foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] 10731 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10732 variants: function 10733 dispatch: 10734 CompositeExplicitAutograd: 
foreach_tensor_addcdiv_scalar_slow 10735 CUDA: foreach_tensor_addcdiv_scalar_cuda 10736 10737- func: _foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] 10738 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10739 variants: function 10740 dispatch: 10741 CompositeExplicitAutograd: foreach_tensor_addcdiv_scalarlist_slow 10742 CUDA: foreach_tensor_addcdiv_scalarlist_cuda 10743 10744- func: _foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] 10745 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10746 variants: function 10747 dispatch: 10748 CompositeExplicitAutograd: foreach_tensor_addcdiv_tensor_slow 10749 CUDA: foreach_tensor_addcdiv_tensor_cuda 10750 10751- func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () 10752 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10753 variants: function 10754 dispatch: 10755 CompositeExplicitAutograd: foreach_tensor_addcdiv_scalar_slow_ 10756 CUDA: foreach_tensor_addcdiv_scalar_cuda_ 10757 autogen: _foreach_addcdiv.Scalar_out 10758 10759- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () 10760 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10761 variants: function 10762 dispatch: 10763 CompositeExplicitAutograd: foreach_tensor_addcdiv_scalarlist_slow_ 10764 CUDA: foreach_tensor_addcdiv_scalarlist_cuda_ 10765 autogen: _foreach_addcdiv.ScalarList_out 10766 10767- func: _foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () 10768 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10769 variants: function 10770 dispatch: 10771 CompositeExplicitAutograd: foreach_tensor_addcdiv_tensor_slow_ 10772 CUDA: foreach_tensor_addcdiv_tensor_cuda_ 10773 autogen: _foreach_addcdiv.Tensor_out 10774 10775- func: _foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] 10776 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10777 variants: function 10778 dispatch: 10779 CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow 10780 CUDA: foreach_tensor_addcmul_scalar_cuda 10781 10782- func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] 10783 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10784 variants: function 10785 dispatch: 10786 CompositeExplicitAutograd: foreach_tensor_addcmul_scalarlist_slow 10787 CUDA: foreach_tensor_addcmul_scalarlist_cuda 10788 10789- func: _foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] 10790 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10791 variants: function 10792 dispatch: 10793 CompositeExplicitAutograd: foreach_tensor_addcmul_tensor_slow 10794 CUDA: foreach_tensor_addcmul_tensor_cuda 10795 10796- func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () 10797 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different 
devices 10798 variants: function 10799 dispatch: 10800 CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_ 10801 CUDA: foreach_tensor_addcmul_scalar_cuda_ 10802 autogen: _foreach_addcmul.Scalar_out 10803 10804- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () 10805 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10806 variants: function 10807 dispatch: 10808 CompositeExplicitAutograd: foreach_tensor_addcmul_scalarlist_slow_ 10809 CUDA: foreach_tensor_addcmul_scalarlist_cuda_ 10810 autogen: _foreach_addcmul.ScalarList_out 10811 10812- func: _foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () 10813 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10814 variants: function 10815 dispatch: 10816 CompositeExplicitAutograd: foreach_tensor_addcmul_tensor_slow_ 10817 CUDA: foreach_tensor_addcmul_tensor_cuda_ 10818 autogen: _foreach_addcmul.Tensor_out 10819 10820- func: _foreach_abs(Tensor[] self) -> Tensor[] 10821 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10822 variants: function 10823 dispatch: 10824 CompositeExplicitAutograd: foreach_tensor_abs_slow 10825 CUDA: foreach_tensor_abs_cuda 10826 10827- func: _foreach_abs_(Tensor(a!)[] self) -> () 10828 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10829 variants: function 10830 dispatch: 10831 CompositeExplicitAutograd: foreach_tensor_abs_slow_ 10832 CUDA: foreach_tensor_abs_cuda_ 10833 autogen: _foreach_abs.out 10834 10835- func: _foreach_acos(Tensor[] self) -> Tensor[] 10836 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10837 variants: function 10838 dispatch: 10839 CompositeExplicitAutograd: foreach_tensor_acos_slow 10840 CUDA: foreach_tensor_acos_cuda 10841 10842- func: _foreach_acos_(Tensor(a!)[] self) -> () 10843 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10844 variants: function 10845 dispatch: 10846 CompositeExplicitAutograd: foreach_tensor_acos_slow_ 10847 CUDA: foreach_tensor_acos_cuda_ 10848 autogen: _foreach_acos.out 10849 10850- func: _foreach_asin(Tensor[] self) -> Tensor[] 10851 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10852 variants: function 10853 dispatch: 10854 CompositeExplicitAutograd: foreach_tensor_asin_slow 10855 CUDA: foreach_tensor_asin_cuda 10856 10857- func: _foreach_asin_(Tensor(a!)[] self) -> () 10858 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10859 variants: function 10860 dispatch: 10861 CompositeExplicitAutograd: foreach_tensor_asin_slow_ 10862 CUDA: foreach_tensor_asin_cuda_ 10863 autogen: _foreach_asin.out 10864 10865- func: _foreach_atan(Tensor[] self) -> Tensor[] 10866 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10867 variants: function 10868 dispatch: 10869 CompositeExplicitAutograd: foreach_tensor_atan_slow 10870 CUDA: foreach_tensor_atan_cuda 10871 10872- func: _foreach_atan_(Tensor(a!)[] self) -> () 10873 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10874 variants: function 10875 dispatch: 10876 CompositeExplicitAutograd: 
foreach_tensor_atan_slow_ 10877 CUDA: foreach_tensor_atan_cuda_ 10878 autogen: _foreach_atan.out 10879 10880- func: _foreach_ceil(Tensor[] self) -> Tensor[] 10881 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10882 variants: function 10883 dispatch: 10884 CompositeExplicitAutograd: foreach_tensor_ceil_slow 10885 CUDA: foreach_tensor_ceil_cuda 10886 10887- func: _foreach_ceil_(Tensor(a!)[] self) -> () 10888 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10889 variants: function 10890 dispatch: 10891 CompositeExplicitAutograd: foreach_tensor_ceil_slow_ 10892 CUDA: foreach_tensor_ceil_cuda_ 10893 autogen: _foreach_ceil.out 10894 10895- func: _foreach_cos(Tensor[] self) -> Tensor[] 10896 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10897 variants: function 10898 dispatch: 10899 CompositeExplicitAutograd: foreach_tensor_cos_slow 10900 CUDA: foreach_tensor_cos_cuda 10901 10902- func: _foreach_cos_(Tensor(a!)[] self) -> () 10903 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10904 variants: function 10905 dispatch: 10906 CompositeExplicitAutograd: foreach_tensor_cos_slow_ 10907 CUDA: foreach_tensor_cos_cuda_ 10908 autogen: _foreach_cos.out 10909 10910- func: _foreach_cosh(Tensor[] self) -> Tensor[] 10911 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10912 variants: function 10913 dispatch: 10914 CompositeExplicitAutograd: foreach_tensor_cosh_slow 10915 CUDA: foreach_tensor_cosh_cuda 10916 10917- func: _foreach_cosh_(Tensor(a!)[] self) -> () 10918 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10919 variants: function 10920 dispatch: 10921 CompositeExplicitAutograd: foreach_tensor_cosh_slow_ 10922 CUDA: foreach_tensor_cosh_cuda_ 10923 autogen: _foreach_cosh.out 10924 10925- func: _foreach_erf(Tensor[] self) -> Tensor[] 10926 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10927 variants: function 10928 dispatch: 10929 CompositeExplicitAutograd: foreach_tensor_erf_slow 10930 CUDA: foreach_tensor_erf_cuda 10931 10932- func: _foreach_erf_(Tensor(a!)[] self) -> () 10933 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10934 variants: function 10935 dispatch: 10936 CompositeExplicitAutograd: foreach_tensor_erf_slow_ 10937 CUDA: foreach_tensor_erf_cuda_ 10938 autogen: _foreach_erf.out 10939 10940- func: _foreach_erfc(Tensor[] self) -> Tensor[] 10941 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10942 variants: function 10943 dispatch: 10944 CompositeExplicitAutograd: foreach_tensor_erfc_slow 10945 CUDA: foreach_tensor_erfc_cuda 10946 10947- func: _foreach_erfc_(Tensor(a!)[] self) -> () 10948 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10949 variants: function 10950 dispatch: 10951 CompositeExplicitAutograd: foreach_tensor_erfc_slow_ 10952 CUDA: foreach_tensor_erfc_cuda_ 10953 autogen: _foreach_erfc.out 10954 10955- func: _foreach_exp(Tensor[] self) -> Tensor[] 10956 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10957 variants: function 10958 dispatch: 10959 CompositeExplicitAutograd: foreach_tensor_exp_slow 10960 CUDA: 
foreach_tensor_exp_cuda 10961 10962- func: _foreach_exp_(Tensor(a!)[] self) -> () 10963 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10964 variants: function 10965 dispatch: 10966 CompositeExplicitAutograd: foreach_tensor_exp_slow_ 10967 CUDA: foreach_tensor_exp_cuda_ 10968 autogen: _foreach_exp.out 10969 10970- func: _foreach_expm1(Tensor[] self) -> Tensor[] 10971 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10972 variants: function 10973 dispatch: 10974 CompositeExplicitAutograd: foreach_tensor_expm1_slow 10975 CUDA: foreach_tensor_expm1_cuda 10976 10977- func: _foreach_expm1_(Tensor(a!)[] self) -> () 10978 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10979 variants: function 10980 dispatch: 10981 CompositeExplicitAutograd: foreach_tensor_expm1_slow_ 10982 CUDA: foreach_tensor_expm1_cuda_ 10983 autogen: _foreach_expm1.out 10984 10985- func: _foreach_floor(Tensor[] self) -> Tensor[] 10986 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10987 variants: function 10988 dispatch: 10989 CompositeExplicitAutograd: foreach_tensor_floor_slow 10990 CUDA: foreach_tensor_floor_cuda 10991 10992- func: _foreach_floor_(Tensor(a!)[] self) -> () 10993 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 10994 variants: function 10995 dispatch: 10996 CompositeExplicitAutograd: foreach_tensor_floor_slow_ 10997 CUDA: foreach_tensor_floor_cuda_ 10998 autogen: _foreach_floor.out 10999 11000- func: _foreach_frac(Tensor[] self) -> Tensor[] 11001 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11002 variants: function 11003 dispatch: 11004 CompositeExplicitAutograd: foreach_tensor_frac_slow 11005 CUDA: foreach_tensor_frac_cuda 11006 11007- func: _foreach_frac_(Tensor(a!)[] self) -> () 11008 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11009 variants: function 11010 dispatch: 11011 CompositeExplicitAutograd: foreach_tensor_frac_slow_ 11012 CUDA: foreach_tensor_frac_cuda_ 11013 autogen: _foreach_frac.out 11014 11015- func: _foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[] 11016 device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices 11017 variants: function 11018 dispatch: 11019 CompositeExplicitAutograd: foreach_tensor_ternary_lerp_slow 11020 CUDA: foreach_tensor_lerp_ternary_cuda 11021 autogen: _foreach_lerp.List_out 11022 11023- func: _foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> () 11024 device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices 11025 variants: function 11026 dispatch: 11027 CompositeExplicitAutograd: foreach_tensor_ternary_lerp_slow_ 11028 CUDA: foreach_tensor_lerp_ternary_cuda_ 11029 autogen: _foreach_lerp.List_out 11030 11031- func: _foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[] 11032 device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices 11033 variants: function 11034 dispatch: 11035 CompositeExplicitAutograd: foreach_tensor_lerp_list_kernel_slow 11036 CUDA: foreach_tensor_lerp_list_cuda 11037 autogen: _foreach_lerp.Scalar_out 11038 11039- func: _foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] 
tensors1, Scalar weight) -> () 11040 device_check: NoCheck # foreach kernels fall back to slow path when tensors are on different devices 11041 variants: function 11042 dispatch: 11043 CompositeExplicitAutograd: foreach_tensor_lerp_list_kernel_slow_ 11044 CUDA: foreach_tensor_lerp_list_cuda_ 11045 autogen: _foreach_lerp.Scalar_out 11046 11047- func: _foreach_lgamma(Tensor[] self) -> Tensor[] 11048 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11049 variants: function 11050 dispatch: 11051 CompositeExplicitAutograd: foreach_tensor_lgamma_slow 11052 CUDA: foreach_tensor_lgamma_cuda 11053 11054- func: _foreach_lgamma_(Tensor(a!)[] self) -> () 11055 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11056 variants: function 11057 dispatch: 11058 CompositeExplicitAutograd: foreach_tensor_lgamma_slow_ 11059 CUDA: foreach_tensor_lgamma_cuda_ 11060 autogen: _foreach_lgamma.out 11061 11062- func: _foreach_log(Tensor[] self) -> Tensor[] 11063 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11064 variants: function 11065 dispatch: 11066 CompositeExplicitAutograd: foreach_tensor_log_slow 11067 CUDA: foreach_tensor_log_cuda 11068 11069- func: _foreach_log_(Tensor(a!)[] self) -> () 11070 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11071 variants: function 11072 dispatch: 11073 CompositeExplicitAutograd: foreach_tensor_log_slow_ 11074 CUDA: foreach_tensor_log_cuda_ 11075 autogen: _foreach_log.out 11076 11077- func: _foreach_log10(Tensor[] self) -> Tensor[] 11078 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11079 variants: function 11080 dispatch: 11081 CompositeExplicitAutograd: foreach_tensor_log10_slow 11082 CUDA: foreach_tensor_log10_cuda 11083 11084- func: _foreach_log10_(Tensor(a!)[] self) -> () 11085 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11086 variants: function 11087 dispatch: 11088 CompositeExplicitAutograd: foreach_tensor_log10_slow_ 11089 CUDA: foreach_tensor_log10_cuda_ 11090 autogen: _foreach_log10.out 11091 11092- func: _foreach_log1p(Tensor[] self) -> Tensor[] 11093 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11094 variants: function 11095 dispatch: 11096 CompositeExplicitAutograd: foreach_tensor_log1p_slow 11097 CUDA: foreach_tensor_log1p_cuda 11098 11099- func: _foreach_log1p_(Tensor(a!)[] self) -> () 11100 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11101 variants: function 11102 dispatch: 11103 CompositeExplicitAutograd: foreach_tensor_log1p_slow_ 11104 CUDA: foreach_tensor_log1p_cuda_ 11105 autogen: _foreach_log1p.out 11106 11107- func: _foreach_log2(Tensor[] self) -> Tensor[] 11108 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11109 variants: function 11110 dispatch: 11111 CompositeExplicitAutograd: foreach_tensor_log2_slow 11112 CUDA: foreach_tensor_log2_cuda 11113 11114- func: _foreach_log2_(Tensor(a!)[] self) -> () 11115 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11116 variants: function 11117 dispatch: 11118 CompositeExplicitAutograd: foreach_tensor_log2_slow_ 11119 CUDA: foreach_tensor_log2_cuda_ 11120 autogen: _foreach_log2.out 11121 
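# Illustrative note, not part of the operator schema: the _foreach_* entries in this
# block are private, list-of-tensors kernels (used, for example, by optimizers with
# foreach=True); the trailing underscore marks the in-place variants. A minimal
# sketch, assuming the standard private bindings under `torch`:
#
# import torch
# xs = [torch.rand(3) for _ in range(4)]
# ys = torch._foreach_sqrt(xs)      # out-of-place: returns a new list of tensors
# torch._foreach_add_(xs, 1.0)      # in-place: mutates every tensor in xs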
11122- func: _foreach_max(Tensor[] self) -> Tensor[] 11123 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11124 variants: function 11125 dispatch: 11126 CompositeExplicitAutograd: foreach_tensor_max_slow 11127 CUDA: foreach_tensor_max_cuda 11128 autogen: _foreach_max.out 11129 11130- func: _foreach_neg(Tensor[] self) -> Tensor[] 11131 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11132 variants: function 11133 dispatch: 11134 CompositeExplicitAutograd: foreach_tensor_neg_slow 11135 CUDA: foreach_tensor_neg_cuda 11136 11137- func: _foreach_neg_(Tensor(a!)[] self) -> () 11138 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11139 variants: function 11140 dispatch: 11141 CompositeExplicitAutograd: foreach_tensor_neg_slow_ 11142 CUDA: foreach_tensor_neg_cuda_ 11143 autogen: _foreach_neg.out 11144 11145- func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[] 11146 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11147 variants: function 11148 dispatch: 11149 CompositeExplicitAutograd: foreach_tensor_norm_slow 11150 CUDA: foreach_tensor_norm_cuda 11151 autogen: _foreach_norm.Scalar_out 11152 11153- func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] 11154 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11155 variants: function 11156 dispatch: 11157 CompositeExplicitAutograd: foreach_tensor_pow_list_kernel_slow 11158 CUDA: foreach_tensor_pow_list_kernel_cuda 11159 11160- func: _foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[] 11161 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11162 variants: function 11163 dispatch: 11164 CompositeExplicitAutograd: foreach_tensor_pow_scalar_kernel_slow 11165 CUDA: foreach_tensor_pow_scalar_kernel_cuda 11166 11167- func: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] 11168 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11169 variants: function 11170 dispatch: 11171 CompositeExplicitAutograd: foreach_tensor_pow_scalarlist_kernel_slow 11172 CUDA: foreach_tensor_pow_scalarlist_kernel_cuda 11173 11174- func: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] 11175 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11176 variants: function 11177 dispatch: 11178 CompositeExplicitAutograd: foreach_scalar_pow_list_kernel_slow 11179 CUDA: foreach_scalar_pow_list_kernel_cuda 11180 11181- func: _foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> () 11182 device_check: NoCheck 11183 variants: function 11184 dispatch: 11185 CompositeExplicitAutograd: foreach_tensor_pow_list_kernel_slow_ 11186 CUDA: foreach_tensor_pow_list_kernel_cuda_ 11187 autogen: _foreach_pow.List_out 11188 11189- func: _foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> () 11190 device_check: NoCheck 11191 variants: function 11192 dispatch: 11193 CompositeExplicitAutograd: foreach_tensor_pow_scalar_kernel_slow_ 11194 CUDA: foreach_tensor_pow_scalar_kernel_cuda_ 11195 autogen: _foreach_pow.Scalar_out 11196 11197- func: _foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> () 11198 device_check: NoCheck 11199 variants: function 11200 dispatch: 11201 
CompositeExplicitAutograd: foreach_tensor_pow_scalarlist_kernel_slow_ 11202 CUDA: foreach_tensor_pow_scalarlist_kernel_cuda_ 11203 autogen: _foreach_pow.ScalarList_out 11204 11205- func: _foreach_reciprocal(Tensor[] self) -> Tensor[] 11206 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11207 variants: function 11208 dispatch: 11209 CompositeExplicitAutograd: foreach_tensor_reciprocal_slow 11210 CUDA: foreach_tensor_reciprocal_cuda 11211 11212- func: _foreach_reciprocal_(Tensor(a!)[] self) -> () 11213 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11214 variants: function 11215 dispatch: 11216 CompositeExplicitAutograd: foreach_tensor_reciprocal_slow_ 11217 CUDA: foreach_tensor_reciprocal_cuda_ 11218 autogen: _foreach_reciprocal.out 11219 11220- func: _foreach_round(Tensor[] self) -> Tensor[] 11221 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11222 variants: function 11223 dispatch: 11224 CompositeExplicitAutograd: foreach_tensor_round_slow 11225 CUDA: foreach_tensor_round_cuda 11226 11227- func: _foreach_round_(Tensor(a!)[] self) -> () 11228 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11229 variants: function 11230 dispatch: 11231 CompositeExplicitAutograd: foreach_tensor_round_slow_ 11232 CUDA: foreach_tensor_round_cuda_ 11233 autogen: _foreach_round.out 11234 11235- func: _foreach_sigmoid(Tensor[] self) -> Tensor[] 11236 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11237 variants: function 11238 dispatch: 11239 CompositeExplicitAutograd: foreach_tensor_sigmoid_slow 11240 CUDA: foreach_tensor_sigmoid_cuda 11241 11242- func: _foreach_sigmoid_(Tensor(a!)[] self) -> () 11243 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11244 variants: function 11245 dispatch: 11246 CompositeExplicitAutograd: foreach_tensor_sigmoid_slow_ 11247 CUDA: foreach_tensor_sigmoid_cuda_ 11248 autogen: _foreach_sigmoid.out 11249 11250- func: _foreach_sign(Tensor[] self) -> Tensor[] 11251 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11252 variants: function 11253 dispatch: 11254 CompositeExplicitAutograd: foreach_tensor_sign_slow 11255 CUDA: foreach_tensor_sign_cuda 11256 11257- func: _foreach_sign_(Tensor(a!)[] self) -> () 11258 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11259 variants: function 11260 dispatch: 11261 CompositeExplicitAutograd: foreach_tensor_sign_slow_ 11262 CUDA: foreach_tensor_sign_cuda_ 11263 autogen: _foreach_sign.out 11264 11265- func: _foreach_sin(Tensor[] self) -> Tensor[] 11266 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11267 variants: function 11268 dispatch: 11269 CompositeExplicitAutograd: foreach_tensor_sin_slow 11270 CUDA: foreach_tensor_sin_cuda 11271 11272- func: _foreach_sin_(Tensor(a!)[] self) -> () 11273 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11274 variants: function 11275 dispatch: 11276 CompositeExplicitAutograd: foreach_tensor_sin_slow_ 11277 CUDA: foreach_tensor_sin_cuda_ 11278 autogen: _foreach_sin.out 11279 11280- func: _foreach_sinh(Tensor[] self) -> Tensor[] 11281 device_check: NoCheck # foreach kernels fall back to slow path when 
tensor are on different devices 11282 variants: function 11283 dispatch: 11284 CompositeExplicitAutograd: foreach_tensor_sinh_slow 11285 CUDA: foreach_tensor_sinh_cuda 11286 11287- func: _foreach_sinh_(Tensor(a!)[] self) -> () 11288 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11289 variants: function 11290 dispatch: 11291 CompositeExplicitAutograd: foreach_tensor_sinh_slow_ 11292 CUDA: foreach_tensor_sinh_cuda_ 11293 autogen: _foreach_sinh.out 11294 11295- func: _foreach_sqrt(Tensor[] self) -> Tensor[] 11296 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11297 variants: function 11298 dispatch: 11299 CompositeExplicitAutograd: foreach_tensor_sqrt_slow 11300 CUDA: foreach_tensor_sqrt_cuda 11301 11302- func: _foreach_sqrt_(Tensor(a!)[] self) -> () 11303 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11304 variants: function 11305 dispatch: 11306 CompositeExplicitAutograd: foreach_tensor_sqrt_slow_ 11307 CUDA: foreach_tensor_sqrt_cuda_ 11308 autogen: _foreach_sqrt.out 11309 11310- func: _foreach_tan(Tensor[] self) -> Tensor[] 11311 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11312 variants: function 11313 dispatch: 11314 CompositeExplicitAutograd: foreach_tensor_tan_slow 11315 CUDA: foreach_tensor_tan_cuda 11316 11317- func: _foreach_tan_(Tensor(a!)[] self) -> () 11318 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11319 variants: function 11320 dispatch: 11321 CompositeExplicitAutograd: foreach_tensor_tan_slow_ 11322 CUDA: foreach_tensor_tan_cuda_ 11323 autogen: _foreach_tan.out 11324 11325- func: _foreach_tanh(Tensor[] self) -> Tensor[] 11326 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11327 variants: function 11328 dispatch: 11329 CompositeExplicitAutograd: foreach_tensor_tanh_slow 11330 CUDA: foreach_tensor_tanh_cuda 11331 11332- func: _foreach_tanh_(Tensor(a!)[] self) -> () 11333 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11334 variants: function 11335 dispatch: 11336 CompositeExplicitAutograd: foreach_tensor_tanh_slow_ 11337 CUDA: foreach_tensor_tanh_cuda_ 11338 autogen: _foreach_tanh.out 11339 11340- func: _foreach_trunc(Tensor[] self) -> Tensor[] 11341 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11342 variants: function 11343 dispatch: 11344 CompositeExplicitAutograd: foreach_tensor_trunc_slow 11345 CUDA: foreach_tensor_trunc_cuda 11346 11347- func: _foreach_trunc_(Tensor(a!)[] self) -> () 11348 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11349 variants: function 11350 dispatch: 11351 CompositeExplicitAutograd: foreach_tensor_trunc_slow_ 11352 CUDA: foreach_tensor_trunc_cuda_ 11353 autogen: _foreach_trunc.out 11354 11355- func: _foreach_zero_(Tensor(a!)[] self) -> () 11356 device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices 11357 variants: function 11358 dispatch: 11359 CompositeExplicitAutograd: foreach_tensor_zero_slow_ 11360 CUDA: foreach_tensor_zero_cuda_ 11361 autogen: _foreach_zero, _foreach_zero.out 11362 11363- func: _foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> () 11364 device_check: NoCheck # foreach kernels 
fall back to slow path when tensor are on different devices 11365 variants: function 11366 dispatch: 11367 CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_ 11368 CUDA: foreach_tensor_copy_list_kernel_cuda_ 11369 autogen: _foreach_copy.out 11370 11371- func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out 11372 device_check: NoCheck 11373 variants: function 11374 dispatch: 11375 CompositeExplicitAutograd: _foreach_copy 11376 11377- func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor 11378 dispatch: 11379 CPU: bucketize_cpu 11380 CUDA: bucketize_cuda 11381 MPS: bucketize_mps 11382 11383- func: bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) 11384 dispatch: 11385 CPU: bucketize_out_cpu 11386 CUDA: bucketize_out_cuda 11387 MPS: bucketize_out_mps 11388 11389- func: bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor 11390 dispatch: 11391 CPU: bucketize_cpu 11392 CUDA: bucketize_cuda 11393 MPS: bucketize_mps 11394 autogen: bucketize.Scalar_out 11395 11396- func: searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor 11397 dispatch: 11398 CPU: searchsorted_cpu 11399 CUDA: searchsorted_cuda 11400 MPS: searchsorted_mps 11401 11402- func: searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) 11403 dispatch: 11404 CPU: searchsorted_out_cpu 11405 CUDA: searchsorted_out_cuda 11406 MPS: searchsorted_out_mps 11407 11408- func: searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor 11409 dispatch: 11410 CPU: searchsorted_cpu 11411 CUDA: searchsorted_cuda 11412 MPS: searchsorted_mps 11413 11414- func: searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) 11415 dispatch: 11416 CPU: searchsorted_out_cpu 11417 CUDA: searchsorted_out_cuda 11418 MPS: searchsorted_out_mps 11419 11420- func: _convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor 11421 structured_delegate: _convert_indices_from_coo_to_csr.out 11422 11423- func: _convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!) 11424 structured: True 11425 dispatch: 11426 CPU: _convert_indices_from_coo_to_csr_structured_cpu 11427 CUDA: _convert_indices_from_coo_to_csr_structured_cuda 11428 11429- func: _convert_indices_from_csr_to_coo(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False) -> Tensor 11430 structured_delegate: _convert_indices_from_csr_to_coo.out 11431 11432- func: _convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!) 11433 structured: True 11434 dispatch: 11435 CPU: _convert_indices_from_csr_to_coo_structured_cpu 11436 CUDA: _convert_indices_from_csr_to_coo_structured_cuda 11437 11438## NN wrappers 11439 11440- func: mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) 
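# Illustrative note, not part of the operator schema: `bucketize` and `searchsorted`
# above are the sorted-lookup primitives; note the flipped argument order (`bucketize`
# takes the values first, `searchsorted` takes the sorted sequence first). A minimal
# sketch, assuming standard torch semantics:
#
# import torch
# boundaries = torch.tensor([1., 3., 5., 7.])
# v = torch.tensor([0.5, 3.0, 6.0, 9.0])
# torch.bucketize(v, boundaries)                   # tensor([0, 1, 3, 4])
# torch.searchsorted(boundaries, v, right=True)    # tensor([0, 2, 3, 4])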
11441 device_check: NoCheck # TensorIterator 11442 structured: True 11443 structured_inherits: TensorIteratorBase 11444 python_module: nn 11445 dispatch: 11446 CPU, CUDA: mse_loss_out 11447 MPS: mse_loss_out_mps 11448 11449- func: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor 11450 device_check: NoCheck # TensorIterator 11451 structured_delegate: mse_loss.out 11452 python_module: nn 11453 11454- func: mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) 11455 python_module: nn 11456 dispatch: 11457 CPU, CUDA: mse_loss_backward_out 11458 MPS: mse_loss_backward_out_mps 11459 11460- func: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor 11461 python_module: nn 11462 dispatch: 11463 CPU, CUDA: mse_loss_backward 11464 MPS: mse_loss_backward_mps 11465 11466- func: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor 11467 python_module: nn 11468 11469- func: multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) 11470 python_module: nn 11471 dispatch: 11472 CPU: multi_margin_loss_cpu_out 11473 CUDA: multi_margin_loss_cuda_out 11474 11475- func: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor 11476 python_module: nn 11477 dispatch: 11478 CPU: multi_margin_loss_cpu 11479 CUDA: multi_margin_loss_cuda 11480 11481- func: multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) 11482 python_module: nn 11483 dispatch: 11484 CPU: multi_margin_loss_cpu_backward_out 11485 CUDA: multi_margin_loss_cuda_backward_out 11486 11487- func: multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor 11488 python_module: nn 11489 dispatch: 11490 CPU: multi_margin_loss_cpu_backward 11491 CUDA: multi_margin_loss_cuda_backward 11492 11493- func: multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) 11494 python_module: nn 11495 11496- func: multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor 11497 python_module: nn 11498 11499- func: multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!)) 11500 python_module: nn 11501 dispatch: 11502 CPU: multilabel_margin_loss_forward_out_cpu 11503 CUDA: multilabel_margin_loss_forward_out_cuda 11504 11505- func: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) 11506 python_module: nn 11507 dispatch: 11508 CPU: multilabel_margin_loss_forward_cpu 11509 CUDA: multilabel_margin_loss_forward_cuda 11510 11511- func: multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!) 
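# Illustrative note, not part of the operator schema: the loss entries in this block are
# marked `python_module: nn`, so they surface through `torch.nn.functional` (and the
# corresponding nn.Module losses) rather than as tensor methods. A minimal sketch,
# assuming standard torch semantics:
#
# import torch
# import torch.nn.functional as F
# pred, target = torch.randn(4, 3), torch.randn(4, 3)
# F.mse_loss(pred, target)                     # scalar, reduction='mean' by default
# F.mse_loss(pred, target, reduction='none')   # per-element losses, shape (4, 3)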
11512 python_module: nn 11513 dispatch: 11514 CPU: multilabel_margin_loss_backward_cpu_out 11515 CUDA: multilabel_margin_loss_backward_cuda_out 11516 11517- func: multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor 11518 python_module: nn 11519 dispatch: 11520 CPU: multilabel_margin_loss_backward_cpu 11521 CUDA: multilabel_margin_loss_backward_cuda 11522 11523- func: nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) 11524 python_module: nn 11525 11526- func: nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor 11527 python_module: nn 11528 dispatch: 11529 CompositeImplicitAutograd: nll_loss_nd_symint 11530 11531- func: nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor 11532 python_module: nn 11533 dispatch: 11534 CompositeImplicitAutograd: nll_loss_symint 11535 11536- func: nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) 11537 python_module: nn 11538 structured: True 11539 dispatch: 11540 CPU: nll_loss_forward_out_cpu 11541 CUDA: nll_loss_forward_out_cuda 11542 MPS: nll_loss_forward_out_mps 11543 11544- func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) 11545 python_module: nn 11546 structured_delegate: nll_loss_forward.output 11547 11548- func: nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) 11549 python_module: nn 11550 structured: True 11551 dispatch: 11552 CPU: nll_loss_backward_out_cpu 11553 CUDA: nll_loss_backward_out_cuda 11554 MPS: nll_loss_backward_out_mps 11555 11556- func: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor 11557 python_module: nn 11558 structured_delegate: nll_loss_backward.grad_input 11559 11560- func: nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) 11561 python_module: nn 11562 11563- func: nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor 11564 python_module: nn 11565 dispatch: 11566 CompositeImplicitAutograd: nll_loss2d_symint 11567 11568- func: nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) 11569 python_module: nn 11570 dispatch: 11571 CPU: nll_loss2d_forward_out_cpu 11572 CUDA: nll_loss2d_forward_out_cuda 11573 MPS: nll_loss2d_forward_out_mps 11574 11575- func: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) 11576 python_module: nn 11577 dispatch: 11578 CPU: nll_loss2d_forward_cpu 11579 CUDA: nll_loss2d_forward_cuda 11580 MPS: nll_loss2d_forward_mps 11581 11582- func: nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) 
grad_input) -> Tensor(a!) 11583 python_module: nn 11584 dispatch: 11585 CPU: nll_loss2d_backward_out_cpu 11586 CUDA: nll_loss2d_backward_out_cuda 11587 MPS: nll_loss2d_backward_out_mps 11588 11589- func: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor 11590 python_module: nn 11591 dispatch: 11592 CPU: nll_loss2d_backward_cpu 11593 CUDA: nll_loss2d_backward_cuda 11594 MPS: nll_loss2d_backward_mps 11595 11596- func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!) 11597 device_check: NoCheck # TensorIterator 11598 structured: True 11599 structured_inherits: TensorIteratorBase 11600 python_module: nn 11601 dispatch: 11602 CPU, CUDA: smooth_l1_loss_out 11603 MPS: smooth_l1_loss_out_mps 11604 11605- func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor 11606 device_check: NoCheck # TensorIterator 11607 structured_delegate: smooth_l1_loss.out 11608 python_module: nn 11609 11610- func: smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!) 11611 python_module: nn 11612 dispatch: 11613 CPU: smooth_l1_loss_backward_out 11614 CUDA: smooth_l1_loss_backward_out 11615 MPS: smooth_l1_loss_backward_out_mps 11616 11617- func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor 11618 python_module: nn 11619 dispatch: 11620 CompositeExplicitAutograd: smooth_l1_loss_backward 11621 11622- func: huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!) 11623 python_module: nn 11624 dispatch: 11625 CPU, CUDA: huber_loss_out 11626 MPS: huber_loss_out_mps 11627 11628- func: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor 11629 python_module: nn 11630 dispatch: 11631 CPU, CUDA: huber_loss 11632 MPS: huber_loss_mps 11633 11634- func: huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!) 11635 python_module: nn 11636 dispatch: 11637 CPU, CUDA: huber_loss_backward_out 11638 MPS: huber_loss_backward_out_mps 11639 11640- func: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor 11641 python_module: nn 11642 dispatch: 11643 CompositeExplicitAutograd: huber_loss_backward 11644 11645- func: soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) 11646 python_module: nn 11647 dispatch: 11648 CompositeExplicitAutograd: soft_margin_loss_out 11649 11650- func: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor 11651 python_module: nn 11652 dispatch: 11653 CompositeExplicitAutograd: soft_margin_loss 11654 11655- func: soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) 
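# Sketch of how the smooth_l1_loss and huber_loss entries above relate
# (illustrative; both are quadratic near zero and linear in the tails, with
# beta/delta setting the crossover point for the residual d = input - target):
#
#   import torch
#   import torch.nn.functional as F
#   x, y = torch.randn(8), torch.randn(8)
#   F.smooth_l1_loss(x, y, beta=1.0)  # 0.5 * d**2 / beta if |d| < beta, else |d| - 0.5 * beta
#   F.huber_loss(x, y, delta=1.0)     # 0.5 * d**2 if |d| < delta, else delta * (|d| - 0.5 * delta)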
11656 python_module: nn 11657 dispatch: 11658 CompositeExplicitAutograd: soft_margin_loss_backward_out 11659 11660- func: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor 11661 python_module: nn 11662 dispatch: 11663 CompositeExplicitAutograd: soft_margin_loss_backward 11664 11665- func: elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!) 11666 structured: True 11667 structured_inherits: TensorIteratorBase 11668 device_check: NoCheck # TensorIterator 11669 python_module: nn 11670 dispatch: 11671 CPU, CUDA: elu_out 11672 MPS: elu_out_mps 11673 11674- func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor 11675 structured_delegate: elu.out 11676 device_check: NoCheck # TensorIterator 11677 python_module: nn 11678 11679- func: elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!) 11680 structured: True 11681 structured_inherits: TensorIteratorBase 11682 python_module: nn 11683 dispatch: 11684 CPU, CUDA: elu_backward_out 11685 MPS: elu_backward_out_mps 11686 11687- func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor 11688 structured_delegate: elu_backward.grad_input 11689 python_module: nn 11690 11691- func: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) 11692 structured_delegate: elu.out 11693 device_check: NoCheck # TensorIterator 11694 python_module: nn 11695 11696- func: glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) 11697 structured: True 11698 structured_inherits: TensorIteratorBase 11699 python_module: nn 11700 dispatch: 11701 CPU, CUDA: glu_out 11702 MPS: glu_out_mps 11703 11704- func: glu(Tensor self, int dim=-1) -> Tensor 11705 structured_delegate: glu.out 11706 device_check: NoCheck # TensorIterator 11707 python_module: nn 11708 11709- func: glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!) 11710 python_module: nn 11711 dispatch: 11712 CPU: glu_backward_cpu_out 11713 CUDA: glu_backward_cuda_out 11714 MPS: glu_backward_mps_out 11715 11716- func: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor 11717 python_module: nn 11718 dispatch: 11719 CPU: glu_backward_cpu 11720 CUDA: glu_backward_cuda 11721 MPS: glu_backward_mps 11722 11723- func: glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor 11724 python_module: nn 11725 dispatch: 11726 CPU, CUDA: glu_jvp 11727 autogen: glu_jvp.out 11728 11729- func: glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor 11730 python_module: nn 11731 dispatch: 11732 CPU, CUDA: glu_backward_jvp 11733 autogen: glu_backward_jvp.out 11734 11735- func: hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
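# glu above is the gated linear unit: the input is split in half along `dim`
# and one half gates the other (illustrative sketch, not part of the schema):
#
#   import torch
#   import torch.nn.functional as F
#   x = torch.randn(2, 6)
#   a, b = x.chunk(2, dim=-1)
#   torch.allclose(F.glu(x, dim=-1), a * torch.sigmoid(b))  # True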
11736 structured: True 11737 structured_inherits: TensorIteratorBase 11738 device_check: NoCheck # TensorIterator 11739 python_module: nn 11740 dispatch: 11741 CPU, CUDA: hardsigmoid_out 11742 MPS: hardsigmoid_out_mps 11743 QuantizedCPU: hardsigmoid_out_quantized_cpu 11744 11745- func: hardsigmoid(Tensor self) -> Tensor 11746 structured_delegate: hardsigmoid.out 11747 device_check: NoCheck # TensorIterator 11748 python_module: nn 11749 dispatch: 11750 QuantizedCPU: hardsigmoid_quantized_cpu 11751 11752- func: hardsigmoid_(Tensor(a!) self) -> Tensor(a!) 11753 structured_delegate: hardsigmoid.out 11754 device_check: NoCheck # TensorIterator 11755 python_module: nn 11756 11757- func: hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) 11758 structured: True 11759 structured_inherits: TensorIteratorBase 11760 python_module: nn 11761 dispatch: 11762 CPU, CUDA: hardsigmoid_backward_out 11763 MPS: hardsigmoid_backward_out_mps 11764 11765- func: hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor 11766 structured_delegate: hardsigmoid_backward.grad_input 11767 python_module: nn 11768 11769- func: hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!) 11770 device_check: NoCheck # TensorIterator 11771 python_module: nn 11772 dispatch: 11773 CPU, CUDA, MPS: hardtanh_out 11774 QuantizedCPU: hardtanh_out_quantized_cpu 11775 11776- func: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor 11777 device_check: NoCheck # TensorIterator 11778 python_module: nn 11779 dispatch: 11780 CPU, CUDA, MPS: hardtanh 11781 QuantizedCPU: hardtanh_quantized_cpu 11782 tags: core 11783 11784- func: hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!) 11785 python_module: nn 11786 dispatch: 11787 CPU, CUDA: hardtanh_backward_out 11788 MPS: hardtanh_backward_out_mps 11789 11790- func: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor 11791 python_module: nn 11792 dispatch: 11793 CPU, CUDA: hardtanh_backward 11794 MPS: hardtanh_backward_mps 11795 11796- func: hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) 11797 device_check: NoCheck # TensorIterator 11798 python_module: nn 11799 dispatch: 11800 CPU, CUDA, MPS: hardtanh_ 11801 QuantizedCPU: hardtanh_quantized_cpu_ 11802 11803- func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 11804 device_check: NoCheck # TensorIterator 11805 python_module: nn 11806 dispatch: 11807 CPU, CUDA: hardswish_out 11808 MPS: hardswish_out_mps 11809 11810- func: hardswish(Tensor self) -> Tensor 11811 device_check: NoCheck # TensorIterator 11812 python_module: nn 11813 dispatch: 11814 CPU, CUDA: hardswish 11815 MPS: hardswish_mps 11816 11817- func: hardswish_(Tensor(a!) self) -> Tensor(a!) 11818 device_check: NoCheck # TensorIterator 11819 python_module: nn 11820 dispatch: 11821 CPU, CUDA: hardswish_ 11822 MPS: hardswish_mps_ 11823 11824- func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor 11825 python_module: nn 11826 dispatch: 11827 CPU, CUDA: hardswish_backward 11828 MPS: hardswish_backward_mps 11829 autogen: hardswish_backward.out 11830 11831- func: leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!) 
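# The hard* activations above are cheap piecewise approximations of their
# smooth counterparts (illustrative sketch; formulas follow the functional docs):
#
#   import torch
#   import torch.nn.functional as F
#   x = torch.linspace(-4., 4., 9)
#   F.hardtanh(x, min_val=-1., max_val=1.)  # same as x.clamp(-1., 1.)
#   F.hardsigmoid(x)                        # relu6(x + 3) / 6
#   F.hardswish(x)                          # x * relu6(x + 3) / 6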
11832 structured: True 11833 structured_inherits: TensorIteratorBase 11834 device_check: NoCheck # TensorIterator 11835 python_module: nn 11836 dispatch: 11837 CPU, CUDA: leaky_relu_out 11838 MPS: leaky_relu_out_mps 11839 QuantizedCPU: leaky_relu_out_quantized_cpu 11840 11841- func: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor 11842 structured_delegate: leaky_relu.out 11843 device_check: NoCheck # TensorIterator 11844 python_module: nn 11845 dispatch: 11846 QuantizedCPU: leaky_relu_quantized_cpu 11847 tags: core 11848 11849- func: leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!) 11850 structured: True 11851 structured_inherits: TensorIteratorBase 11852 python_module: nn 11853 dispatch: 11854 CPU, CUDA: leaky_relu_backward_out 11855 MPS: leaky_relu_backward_out_mps 11856 11857- func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor 11858 structured_delegate: leaky_relu_backward.grad_input 11859 python_module: nn 11860 11861- func: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) 11862 structured_delegate: leaky_relu.out 11863 device_check: NoCheck # TensorIterator 11864 python_module: nn 11865 dispatch: 11866 QuantizedCPU: leaky_relu_quantized_cpu_ 11867 11868- func: log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 11869 device_check: NoCheck # TensorIterator 11870 python_module: nn 11871 11872- func: log_sigmoid(Tensor self) -> Tensor 11873 device_check: NoCheck # TensorIterator 11874 python_module: nn 11875 11876- func: log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!)) 11877 device_check: NoCheck # TensorIterator 11878 python_module: nn 11879 dispatch: 11880 CPU: log_sigmoid_forward_out_cpu 11881 CUDA: log_sigmoid_forward_out_cuda 11882 MPS: log_sigmoid_forward_out_mps 11883 11884- func: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) 11885 device_check: NoCheck # TensorIterator 11886 python_module: nn 11887 dispatch: 11888 CPU: log_sigmoid_forward_cpu 11889 CUDA: log_sigmoid_forward_cuda 11890 MPS: log_sigmoid_forward_mps 11891 11892- func: log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!) 11893 python_module: nn 11894 dispatch: 11895 CPU: log_sigmoid_backward_cpu_out 11896 CUDA: log_sigmoid_backward_cuda_out 11897 MPS: log_sigmoid_backward_mps_out 11898 11899- func: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor 11900 python_module: nn 11901 dispatch: 11902 CPU: log_sigmoid_backward_cpu 11903 CUDA: log_sigmoid_backward_cuda 11904 MPS: log_sigmoid_backward_mps 11905 11906- func: rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) 11907 python_module: nn 11908 tags: nondeterministic_seeded 11909 dispatch: 11910 CPU: rrelu_with_noise_out_cpu 11911 CUDA: rrelu_with_noise_out_cuda 11912 11913- func: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? 
generator=None) -> Tensor 11914 python_module: nn 11915 dispatch: 11916 CPU: rrelu_with_noise_cpu 11917 CUDA: rrelu_with_noise_cuda 11918 tags: nondeterministic_seeded 11919 11920- func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor 11921 python_module: nn 11922 dispatch: 11923 CompositeExplicitAutograd: rrelu_with_noise_backward 11924 autogen: rrelu_with_noise_backward.out 11925 11926- func: rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) 11927 python_module: nn 11928 tags: nondeterministic_seeded 11929 dispatch: 11930 CPU: rrelu_with_noise_cpu_ 11931 CUDA: rrelu_with_noise_cuda_ 11932 11933- func: softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!) 11934 structured: True 11935 structured_inherits: TensorIteratorBase 11936 device_check: NoCheck # TensorIterator 11937 python_module: nn 11938 dispatch: 11939 CPU, CUDA: softplus_out 11940 MPS: softplus_out_mps 11941 11942- func: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor 11943 structured_delegate: softplus.out 11944 device_check: NoCheck # TensorIterator 11945 python_module: nn 11946 11947- func: softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) 11948 structured: True 11949 structured_inherits: TensorIteratorBase 11950 python_module: nn 11951 dispatch: 11952 CPU, CUDA: softplus_backward_out 11953 MPS: softplus_backward_out_mps 11954 11955- func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor 11956 structured_delegate: softplus_backward.grad_input 11957 python_module: nn 11958 11959- func: softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) 11960 structured: True 11961 structured_inherits: TensorIteratorBase 11962 device_check: NoCheck # TensorIterator 11963 python_module: nn 11964 dispatch: 11965 CPU, CUDA: softshrink_out 11966 MPS: softshrink_out_mps 11967 11968- func: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor 11969 structured_delegate: softshrink.out 11970 device_check: NoCheck # TensorIterator 11971 python_module: nn 11972 11973- func: softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) 11974 structured: True 11975 structured_inherits: TensorIteratorBase 11976 python_module: nn 11977 dispatch: 11978 CPU, CUDA: softshrink_backward_out 11979 MPS: softshrink_backward_out_mps 11980 11981- func: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor 11982 structured_delegate: softshrink_backward.grad_input 11983 python_module: nn 11984 11985- func: adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) 
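# Quick reference for the softplus/softshrink entries above (illustrative
# sketch; formulas follow the functional docs):
#
#   import torch
#   import torch.nn.functional as F
#   x = torch.randn(8)
#   F.softplus(x, beta=1, threshold=20)  # (1 / beta) * log(1 + exp(beta * x)); linear once beta * x > threshold
#   F.softshrink(x, lambd=0.5)           # x - lambd if x > lambd, x + lambd if x < -lambd, else 0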
11986 python_module: nn 11987 dispatch: 11988 CPU: adaptive_avg_pool2d_out_cpu 11989 CUDA: adaptive_avg_pool2d_out_cuda 11990 MPS: adaptive_avg_pool2d_out_mps 11991 MkldnnCPU: mkldnn_adaptive_avg_pool2d_out_stub 11992 11993- func: adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor 11994 python_module: nn 11995 dispatch: 11996 CompositeImplicitAutograd: adaptive_avg_pool2d_symint 11997 11998- func: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor 11999 dispatch: 12000 MkldnnCPU: mkldnn_adaptive_avg_pool2d 12001 12002- func: mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!) 12003 dispatch: 12004 MkldnnCPU: mkldnn_adaptive_avg_pool2d_out 12005 12006- func: mkldnn_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor 12007 dispatch: 12008 MkldnnCPU: mkldnn_adaptive_avg_pool2d_backward 12009 autogen: mkldnn_adaptive_avg_pool2d_backward.out 12010 12011- func: _adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor 12012 dispatch: 12013 CPU: adaptive_avg_pool2d_cpu 12014 CUDA: adaptive_avg_pool2d_cuda 12015 MPS: adaptive_avg_pool2d_mps 12016 QuantizedCPU: adaptive_avg_pool2d_quantized_cpu 12017 QuantizedCUDA: adaptive_avg_pool2d_quantized_cuda 12018 autogen: _adaptive_avg_pool2d.out 12019 tags: core 12020 12021- func: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor 12022 python_module: nn 12023 dispatch: 12024 CPU: adaptive_avg_pool2d_backward_cpu 12025 CUDA: adaptive_avg_pool2d_backward_cuda 12026 MPS: adaptive_avg_pool2d_backward_mps 12027 autogen: _adaptive_avg_pool2d_backward.out 12028 tags: core 12029 12030- func: adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) 12031 python_module: nn 12032 dispatch: 12033 CPU: adaptive_avg_pool3d_out_cpu 12034 CUDA: adaptive_avg_pool3d_out_cuda 12035 QuantizedCPU: adaptive_avg_pool3d_out_quantized_cpu 12036 12037- func: adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor 12038 python_module: nn 12039 dispatch: 12040 CompositeImplicitAutograd: adaptive_avg_pool3d_symint 12041 12042- func: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor 12043 dispatch: 12044 CPU: adaptive_avg_pool3d_cpu 12045 CUDA: adaptive_avg_pool3d_cuda 12046 QuantizedCPU: adaptive_avg_pool3d_quantized_cpu 12047 autogen: _adaptive_avg_pool3d.out 12048 tags: core 12049 12050- func: adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) 12051 python_module: nn 12052 dispatch: 12053 CPU: adaptive_avg_pool3d_backward_out_cpu 12054 CUDA: adaptive_avg_pool3d_backward_out_cuda 12055 12056- func: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor 12057 python_module: nn 12058 dispatch: 12059 CPU: adaptive_avg_pool3d_backward_cpu 12060 CUDA: adaptive_avg_pool3d_backward_cuda 12061 autogen: _adaptive_avg_pool3d_backward.out 12062 12063# Return: (Tensor output, Tensor indices) 12064- func: adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) 12065 python_module: nn 12066 structured: True 12067 dispatch: 12068 CPU: adaptive_max_pool2d_out_cpu 12069 CUDA: adaptive_max_pool2d_out_cuda 12070 MPS: adaptive_max_pool2d_out_mps 12071 12072# Return: (Tensor output, Tensor indices) 12073- func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) 12074 python_module: nn 12075 structured_delegate: adaptive_max_pool2d.out 12076 12077- func: adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 12078 python_module: nn 12079 structured: True 12080 dispatch: 12081 CPU: adaptive_max_pool2d_backward_out_cpu 12082 CUDA: adaptive_max_pool2d_backward_out_cuda 12083 MPS: adaptive_max_pool2d_backward_out_mps 12084 12085- func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor 12086 python_module: nn 12087 structured_delegate: adaptive_max_pool2d_backward.grad_input 12088 12089# Return: (Tensor output, Tensor indices) 12090- func: adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) 12091 python_module: nn 12092 structured: True 12093 dispatch: 12094 CPU: adaptive_max_pool3d_out_cpu 12095 CUDA: adaptive_max_pool3d_out_cuda 12096 12097# Return: (Tensor output, Tensor indices) 12098- func: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) 12099 python_module: nn 12100 structured_delegate: adaptive_max_pool3d.out 12101 12102- func: adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 12103 python_module: nn 12104 structured: True 12105 dispatch: 12106 CPU: adaptive_max_pool3d_backward_out_cpu 12107 CUDA: adaptive_max_pool3d_backward_out_cuda 12108 12109- func: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor 12110 python_module: nn 12111 structured_delegate: adaptive_max_pool3d_backward.grad_input 12112 12113- func: avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) 12114 python_module: nn 12115 structured: True 12116 precomputed: 12117 - kernel_size -> int kH, int kW 12118 - stride -> int dH, int dW 12119 - padding -> int padH, int padW 12120 dispatch: 12121 CPU: avg_pool2d_out_cpu 12122 CUDA: avg_pool2d_out_cuda 12123 MPS: avg_pool2d_out_mps 12124 MkldnnCPU: mkldnn_avg_pool2d_out 12125 12126- func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor 12127 python_module: nn 12128 structured_delegate: avg_pool2d.out 12129 dispatch: 12130 MkldnnCPU: mkldnn_avg_pool2d 12131 QuantizedCPU: avg_pool2d_quantized_cpu 12132 tags: core 12133 12134- func: avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) 
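# The adaptive pooling entries above take a target output size instead of an
# explicit kernel/stride, whereas avg_pool2d takes the kernel directly
# (illustrative sketch, not part of the schema):
#
#   import torch
#   import torch.nn.functional as F
#   x = torch.randn(1, 3, 32, 32)
#   F.adaptive_avg_pool2d(x, (7, 7)).shape                        # torch.Size([1, 3, 7, 7])
#   out, idx = F.adaptive_max_pool2d(x, 7, return_indices=True)   # idx can feed unpooling-style ops
#   F.avg_pool2d(x, kernel_size=2).shape                          # torch.Size([1, 3, 16, 16])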
12135 python_module: nn 12136 structured: True 12137 dispatch: 12138 CPU: avg_pool2d_backward_out_cpu 12139 CUDA: avg_pool2d_backward_out_cuda 12140 MPS: avg_pool2d_backward_out_mps 12141 MkldnnCPU: mkldnn_avg_pool2d_backward_out 12142 12143- func: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor 12144 python_module: nn 12145 structured_delegate: avg_pool2d_backward.grad_input 12146 dispatch: 12147 MkldnnCPU: mkldnn_avg_pool2d_backward 12148 tags: core 12149 12150- func: avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) 12151 python_module: nn 12152 structured: True 12153 dispatch: 12154 CPU: avg_pool3d_out_cpu 12155 CUDA: avg_pool3d_out_cuda 12156 MkldnnCPU: mkldnn_avg_pool3d_out 12157 12158- func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor 12159 python_module: nn 12160 structured_delegate: avg_pool3d.out 12161 dispatch: 12162 MkldnnCPU: mkldnn_avg_pool3d 12163 QuantizedCPU: avg_pool3d_quantized_cpu 12164 tags: core 12165 12166- func: avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) 12167 python_module: nn 12168 structured: True 12169 dispatch: 12170 CPU: avg_pool3d_backward_out_cpu 12171 CUDA: avg_pool3d_backward_out_cuda 12172 MkldnnCPU: mkldnn_avg_pool3d_backward_out 12173 12174- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor 12175 python_module: nn 12176 structured_delegate: avg_pool3d_backward.grad_input 12177 dispatch: 12178 MkldnnCPU: mkldnn_avg_pool3d_backward 12179 12180# Return: (Tensor output, Tensor indices) 12181- func: fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) 12182 python_module: nn 12183 structured: True 12184 dispatch: 12185 CPU: fractional_max_pool2d_out_cpu 12186 CUDA: fractional_max_pool2d_out_cuda 12187 12188# Return: (Tensor output, Tensor indices) 12189- func: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) 12190 python_module: nn 12191 structured_delegate: fractional_max_pool2d.output 12192 12193- func: fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 12194 python_module: nn 12195 structured: True 12196 dispatch: 12197 CPU: fractional_max_pool2d_backward_cpu 12198 CUDA: fractional_max_pool2d_backward_cuda 12199 12200- func: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor 12201 python_module: nn 12202 structured_delegate: fractional_max_pool2d_backward.grad_input 12203 12204# Return: (Tensor output, Tensor indices) 12205- func: fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) 12206 python_module: nn 12207 structured: True 12208 precomputed: 12209 - kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW 12210 - output_size -> int outputT, int outputH, int outputW 12211 - int numBatch, int numPlanes, int inputT, int inputH, int inputW 12212 dispatch: 12213 CPU: fractional_max_pool3d_out_cpu 12214 CUDA: fractional_max_pool3d_out_cuda 12215 12216# Return: (Tensor output, Tensor indices) 12217- func: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) 12218 python_module: nn 12219 structured_delegate: fractional_max_pool3d.output 12220 12221- func: fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 12222 python_module: nn 12223 dispatch: 12224 CPU: fractional_max_pool3d_backward_out_cpu 12225 CUDA: fractional_max_pool3d_backward_out_cuda 12226 12227- func: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor 12228 python_module: nn 12229 dispatch: 12230 CPU: fractional_max_pool3d_backward_cpu 12231 CUDA: fractional_max_pool3d_backward_cuda 12232 12233# Return: (Tensor output, Tensor indices) 12234- func: max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) 12235 python_module: nn 12236 structured: True 12237 dispatch: 12238 CPU: max_pool2d_with_indices_out_cpu 12239 CUDA: max_pool2d_with_indices_out_cuda 12240 MPS: max_pool2d_with_indices_out_mps 12241 12242# Return: (Tensor output, Tensor indices) 12243- func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) 12244 python_module: nn 12245 structured_delegate: max_pool2d_with_indices.out 12246 tags: core 12247 12248- func: max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 12249 python_module: nn 12250 structured: True 12251 dispatch: 12252 CPU: max_pool2d_with_indices_backward_out_cpu 12253 CUDA: max_pool2d_with_indices_backward_out_cuda 12254 MPS: max_pool2d_with_indices_backward_out_mps 12255 12256- func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor 12257 python_module: nn 12258 structured_delegate: max_pool2d_with_indices_backward.grad_input 12259 tags: core 12260 12261# Return: (Tensor output, Tensor indices) 12262- func: max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) 
indices) -> (Tensor(a!), Tensor(b!)) 12263 python_module: nn 12264 dispatch: 12265 CPU: max_pool3d_with_indices_out_cpu 12266 CUDA: max_pool3d_with_indices_out_cuda 12267 12268# Return: (Tensor output, Tensor indices) 12269- func: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) 12270 python_module: nn 12271 dispatch: 12272 CPU: max_pool3d_with_indices_cpu 12273 CUDA: max_pool3d_with_indices_cuda 12274 tags: core 12275 12276- func: max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) 12277 python_module: nn 12278 dispatch: 12279 CPU: max_pool3d_with_indices_backward_out_cpu 12280 CUDA: max_pool3d_with_indices_backward_out_cuda 12281 12282- func: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor 12283 python_module: nn 12284 dispatch: 12285 CPU: max_pool3d_with_indices_backward_cpu 12286 CUDA: max_pool3d_with_indices_backward_cuda 12287 12288- func: max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) 12289 python_module: nn 12290 dispatch: 12291 CPU: max_unpooling2d_forward_out_cpu 12292 CUDA: max_unpooling2d_forward_out_cuda 12293 12294- func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor 12295 python_module: nn 12296 dispatch: 12297 CPU: max_unpooling2d_forward_cpu 12298 CUDA: max_unpooling2d_forward_cuda 12299 12300- func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) 12301 python_module: nn 12302 dispatch: 12303 CPU: max_unpooling3d_forward_out_cpu 12304 CUDA: max_unpooling3d_forward_out_cuda 12305 12306- func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor 12307 python_module: nn 12308 dispatch: 12309 CPU: max_unpooling3d_forward_cpu 12310 CUDA: max_unpooling3d_forward_cuda 12311 12312- func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) 12313 python_module: nn 12314 structured: True 12315 dispatch: 12316 CPU: reflection_pad1d_out_cpu 12317 QuantizedCPU: reflection_pad1d_out_quantized_cpu 12318 CUDA: reflection_pad1d_out_cuda 12319 MPS: reflection_pad1d_out_mps 12320 12321- func: reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor 12322 python_module: nn 12323 structured_delegate: reflection_pad1d.out 12324 tags: core 12325 12326- func: reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 12327 python_module: nn 12328 structured: True 12329 dispatch: 12330 CPU: reflection_pad1d_backward_out_cpu 12331 CUDA: reflection_pad1d_backward_out_cuda 12332 MPS: reflection_pad1d_backward_out_mps 12333 12334- func: reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor 12335 python_module: nn 12336 structured_delegate: reflection_pad1d_backward.grad_input 12337 12338- func: reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) 
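# The max_pool*_with_indices and max_unpool* entries above are designed to
# round-trip: the indices returned by pooling are what unpooling consumes
# (illustrative sketch):
#
#   import torch
#   import torch.nn.functional as F
#   x = torch.randn(1, 1, 4, 4)
#   pooled, indices = F.max_pool2d(x, kernel_size=2, return_indices=True)
#   restored = F.max_unpool2d(pooled, indices, kernel_size=2)  # zeros except at the argmax locations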
12339 python_module: nn 12340 dispatch: 12341 CPU, QuantizedCPU: reflection_pad2d_out_cpu 12342 CUDA: reflection_pad2d_out_cuda 12343 MPS: reflection_pad2d_out_mps 12344 12345- func: reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor 12346 python_module: nn 12347 dispatch: 12348 CPU: reflection_pad2d_cpu 12349 QuantizedCPU: reflection_pad2d_quantized_cpu 12350 CUDA: reflection_pad2d_cuda 12351 MPS: reflection_pad2d_mps 12352 tags: core 12353 12354- func: reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 12355 python_module: nn 12356 dispatch: 12357 CPU: reflection_pad2d_backward_out_cpu 12358 CUDA: reflection_pad2d_backward_out_cuda 12359 MPS: reflection_pad2d_backward_out_mps 12360 12361- func: reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor 12362 python_module: nn 12363 dispatch: 12364 CPU: reflection_pad2d_backward_cpu 12365 CUDA: reflection_pad2d_backward_cuda 12366 MPS: reflection_pad2d_backward_mps 12367 12368- func: reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) 12369 python_module: nn 12370 structured: True 12371 dispatch: 12372 CPU: reflection_pad3d_out_cpu 12373 CUDA: reflection_pad3d_out_cuda 12374 MPS: reflection_pad3d_out_mps 12375 12376- func: reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor 12377 python_module: nn 12378 structured_delegate: reflection_pad3d.out 12379 tags: core 12380 12381- func: reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 12382 python_module: nn 12383 structured: True 12384 dispatch: 12385 CPU: reflection_pad3d_backward_out_cpu 12386 CUDA: reflection_pad3d_backward_out_cuda 12387 MPS: reflection_pad3d_backward_out_mps 12388 12389- func: reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor 12390 python_module: nn 12391 structured_delegate: reflection_pad3d_backward.grad_input 12392 12393- func: replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) 12394 python_module: nn 12395 structured: True 12396 dispatch: 12397 CPU: replication_pad1d_out_cpu 12398 CUDA: replication_pad1d_out_cuda 12399 MPS: replication_pad1d_out_mps 12400 12401- func: replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor 12402 python_module: nn 12403 structured_delegate: replication_pad1d.out 12404 12405- func: replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 12406 python_module: nn 12407 structured: True 12408 dispatch: 12409 CPU: replication_pad1d_backward_out_cpu 12410 CUDA: replication_pad1d_backward_out_cuda 12411 MPS: replication_pad1d_backward_out_mps 12412 12413- func: replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor 12414 python_module: nn 12415 structured_delegate: replication_pad1d_backward.grad_input 12416 12417- func: replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) 
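# For the 2d padding entries above, `padding` is ordered (left, right, top,
# bottom); reflection mirrors the interior without repeating the border,
# replication repeats the border row/column. An illustrative sketch via the
# functional wrapper, which is expected to lower to these ops for 4-d input
# (the exact lowering is an assumption, not part of the schema):
#
#   import torch
#   import torch.nn.functional as F
#   x = torch.arange(9.).reshape(1, 1, 3, 3)
#   F.pad(x, (1, 1, 1, 1), mode='reflect')    # reflection_pad2d
#   F.pad(x, (1, 1, 1, 1), mode='replicate')  # replication_pad2d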
12418 python_module: nn 12419 structured: True 12420 dispatch: 12421 CPU: replication_pad2d_out_cpu 12422 CUDA: replication_pad2d_out_cuda 12423 MPS: replication_pad2d_out_mps 12424 12425- func: replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor 12426 python_module: nn 12427 structured_delegate: replication_pad2d.out 12428 tags: core 12429 12430- func: replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 12431 python_module: nn 12432 dispatch: 12433 CPU: replication_pad2d_backward_out_cpu 12434 CUDA: replication_pad2d_backward_out_cuda 12435 MPS: replication_pad2d_backward_out_mps 12436 12437- func: replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor 12438 python_module: nn 12439 dispatch: 12440 CPU: replication_pad2d_backward_cpu 12441 CUDA: replication_pad2d_backward_cuda 12442 MPS: replication_pad2d_backward_mps 12443 12444- func: replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) 12445 python_module: nn 12446 structured: True 12447 dispatch: 12448 CPU: replication_pad3d_out_cpu 12449 CUDA: replication_pad3d_out_cuda 12450 MPS: replication_pad3d_out_mps 12451 12452- func: replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor 12453 python_module: nn 12454 structured_delegate: replication_pad3d.out 12455 tags: core 12456 12457 12458- func: replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) 12459 python_module: nn 12460 dispatch: 12461 CPU: replication_pad3d_backward_out_cpu 12462 CUDA: replication_pad3d_backward_out_cuda 12463 MPS: replication_pad3d_backward_out_mps 12464 12465- func: replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor 12466 python_module: nn 12467 dispatch: 12468 CPU: replication_pad3d_backward_cpu 12469 CUDA: replication_pad3d_backward_cuda 12470 MPS: replication_pad3d_backward_mps 12471 12472- func: _pad_circular(Tensor self, SymInt[] pad) -> Tensor 12473 python_module: nn 12474 dispatch: 12475 CompositeImplicitAutograd: _pad_circular_symint 12476 12477- func: _pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor 12478 python_module: nn 12479 dispatch: 12480 CompositeImplicitAutograd: _pad_enum_symint 12481 12482- func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor 12483 python_module: nn 12484 dispatch: 12485 CompositeImplicitAutograd: pad_symint 12486 12487- func: upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor 12488 python_module: nn 12489 autogen: upsample_linear1d.vec_out 12490 12491- func: upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor 12492 python_module: nn 12493 autogen: upsample_bilinear2d.vec_out 12494 tags: core 12495 12496- func: _upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor 12497 python_module: nn 12498 autogen: _upsample_bilinear2d_aa.vec_out 12499 12500- func: upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor 12501 python_module: nn 12502 autogen: upsample_trilinear3d.vec_out 12503 12504- func: upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? 
scale_factors) -> Tensor 12505 python_module: nn 12506 autogen: upsample_bicubic2d.vec_out 12507 12508- func: _upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor 12509 python_module: nn 12510 autogen: _upsample_bicubic2d_aa.vec_out 12511 12512- func: upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor 12513 python_module: nn 12514 autogen: upsample_nearest1d.vec_out 12515 12516- func: _upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor 12517 python_module: nn 12518 autogen: _upsample_nearest_exact1d.vec_out 12519 12520- func: upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor 12521 python_module: nn 12522 autogen: upsample_nearest2d.vec_out 12523 tags: core 12524 12525- func: _upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor 12526 python_module: nn 12527 autogen: _upsample_nearest_exact2d.vec_out 12528 12529- func: upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor 12530 python_module: nn 12531 autogen: upsample_nearest3d.vec_out 12532 12533- func: _upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor 12534 python_module: nn 12535 autogen: _upsample_nearest_exact3d.vec_out 12536 12537# NOTE: all of the non-"vec" upsample overloads are only kept for backward compatibility. 12538- func: upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) 12539 python_module: nn 12540 structured: True 12541 dispatch: 12542 CPU: upsample_linear1d_out_cpu 12543 CUDA: upsample_linear1d_out_cuda 12544 MPS: upsample_linear1d_out_mps 12545 12546- func: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor 12547 python_module: nn 12548 structured_delegate: upsample_linear1d.out 12549 12550- func: upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12551 python_module: nn 12552 structured: True 12553 dispatch: 12554 CPU: upsample_linear1d_backward_out_cpu 12555 CUDA: upsample_linear1d_backward_out_cuda 12556 MPS: upsample_linear1d_backward_out_mps 12557 12558- func: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor 12559 python_module: nn 12560 structured_delegate: upsample_linear1d_backward.grad_input 12561 12562- func: upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 12563 python_module: nn 12564 structured: True 12565 dispatch: 12566 CPU: upsample_bilinear2d_out_cpu 12567 CUDA: upsample_bilinear2d_out_cuda 12568 MPS: upsample_bilinear2d_out_mps 12569 12570- func: upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor 12571 python_module: nn 12572 structured_delegate: upsample_bilinear2d.out 12573 dispatch: 12574 QuantizedCPU: upsample_bilinear2d_quantized_cpu 12575 12576- func: upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? 
scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12577 python_module: nn 12578 structured: True 12579 dispatch: 12580 CPU: upsample_bilinear2d_backward_out_cpu 12581 CUDA: upsample_bilinear2d_backward_out_cuda 12582 MPS: upsample_bilinear2d_backward_out_mps 12583 12584- func: upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor 12585 python_module: nn 12586 structured_delegate: upsample_bilinear2d_backward.grad_input 12587 12588- func: _upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 12589 python_module: nn 12590 structured: True 12591 dispatch: 12592 CPU: _upsample_bilinear2d_aa_out_cpu 12593 CUDA: _upsample_bilinear2d_aa_out_cuda 12594 12595- func: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor 12596 python_module: nn 12597 structured_delegate: _upsample_bilinear2d_aa.out 12598 12599- func: _upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12600 python_module: nn 12601 structured: True 12602 dispatch: 12603 CPU: _upsample_bilinear2d_aa_backward_out_cpu 12604 CUDA: _upsample_bilinear2d_aa_backward_out_cuda 12605 12606- func: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor 12607 python_module: nn 12608 structured_delegate: _upsample_bilinear2d_aa_backward.grad_input 12609 12610- func: upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 12611 python_module: nn 12612 structured: True 12613 dispatch: 12614 CPU: upsample_bicubic2d_out_cpu 12615 CUDA: upsample_bicubic2d_out_cuda 12616 12617- func: upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor 12618 python_module: nn 12619 structured_delegate: upsample_bicubic2d.out 12620 12621- func: upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12622 python_module: nn 12623 structured: True 12624 dispatch: 12625 CPU: upsample_bicubic2d_backward_out_cpu 12626 CUDA: upsample_bicubic2d_backward_out_cuda 12627 12628- func: upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor 12629 python_module: nn 12630 structured_delegate: upsample_bicubic2d_backward.grad_input 12631 12632- func: _upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 12633 python_module: nn 12634 structured: True 12635 dispatch: 12636 CPU: _upsample_bicubic2d_aa_out_cpu 12637 CUDA: _upsample_bicubic2d_aa_out_cuda 12638 12639- func: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? 
scales_w=None) -> Tensor 12640 python_module: nn 12641 structured_delegate: _upsample_bicubic2d_aa.out 12642 12643- func: _upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12644 python_module: nn 12645 structured: True 12646 dispatch: 12647 CPU: _upsample_bicubic2d_aa_backward_out_cpu 12648 CUDA: _upsample_bicubic2d_aa_backward_out_cuda 12649 12650- func: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor 12651 python_module: nn 12652 structured_delegate: _upsample_bicubic2d_aa_backward.grad_input 12653 12654- func: upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 12655 python_module: nn 12656 structured: True 12657 dispatch: 12658 CPU: upsample_trilinear3d_out_cpu 12659 CUDA: upsample_trilinear3d_out_cuda 12660 12661- func: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor 12662 python_module: nn 12663 structured_delegate: upsample_trilinear3d.out 12664 12665- func: upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12666 python_module: nn 12667 structured: True 12668 dispatch: 12669 CPU: upsample_trilinear3d_backward_out_cpu 12670 CUDA: upsample_trilinear3d_backward_out_cuda 12671 12672- func: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor 12673 python_module: nn 12674 structured_delegate: upsample_trilinear3d_backward.grad_input 12675 12676- func: upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) 12677 python_module: nn 12678 structured: True 12679 dispatch: 12680 CPU: upsample_nearest1d_out_cpu 12681 CUDA: upsample_nearest1d_out_cuda 12682 MPS: upsample_nearest1d_out_mps 12683 12684- func: _upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) 12685 python_module: nn 12686 structured: True 12687 dispatch: 12688 CPU: _upsample_nearest_exact1d_out_cpu 12689 CUDA: _upsample_nearest_exact1d_out_cuda 12690 MPS: _upsample_nearest_exact1d_out_mps 12691 12692- func: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor 12693 python_module: nn 12694 structured_delegate: upsample_nearest1d.out 12695 12696- func: _upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor 12697 python_module: nn 12698 structured_delegate: _upsample_nearest_exact1d.out 12699 12700- func: upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) 
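# The upsample_* entries above are normally reached through F.interpolate,
# which selects an overload from mode/antialias (illustrative sketch; the
# exact overload mapping noted in the trailing comments is an assumption
# about the Python-side dispatch, not part of the schema):
#
#   import torch
#   import torch.nn.functional as F
#   x = torch.randn(1, 3, 8, 8)
#   F.interpolate(x, scale_factor=2., mode='nearest')                                        # upsample_nearest2d
#   F.interpolate(x, size=(16, 16), mode='bilinear', align_corners=False)                    # upsample_bilinear2d
#   F.interpolate(x, size=(16, 16), mode='bilinear', align_corners=False, antialias=True)    # _upsample_bilinear2d_aa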
12701 python_module: nn 12702 structured: True 12703 dispatch: 12704 CPU: upsample_nearest1d_backward_out_cpu 12705 CUDA: upsample_nearest1d_backward_out_cuda 12706 MPS: upsample_nearest1d_backward_out_mps 12707 12708- func: _upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12709 python_module: nn 12710 structured: True 12711 dispatch: 12712 CPU: _upsample_nearest_exact1d_backward_out_cpu 12713 CUDA: _upsample_nearest_exact1d_backward_out_cuda 12714 MPS: _upsample_nearest_exact1d_backward_out_mps 12715 12716- func: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor 12717 python_module: nn 12718 structured_delegate: upsample_nearest1d_backward.grad_input 12719 12720- func: _upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor 12721 python_module: nn 12722 structured_delegate: _upsample_nearest_exact1d_backward.grad_input 12723 12724- func: upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 12725 python_module: nn 12726 structured: True 12727 dispatch: 12728 CPU: upsample_nearest2d_out_cpu 12729 CUDA: upsample_nearest2d_out_cuda 12730 MPS: upsample_nearest2d_out_mps 12731 12732- func: _upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 12733 python_module: nn 12734 structured: True 12735 dispatch: 12736 CPU: _upsample_nearest_exact2d_out_cpu 12737 CUDA: _upsample_nearest_exact2d_out_cuda 12738 MPS: _upsample_nearest_exact2d_out_mps 12739 12740- func: upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor 12741 python_module: nn 12742 structured_delegate: upsample_nearest2d.out 12743 dispatch: 12744 QuantizedCPU: upsample_nearest2d_quantized_cpu 12745 12746- func: _upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor 12747 python_module: nn 12748 structured_delegate: _upsample_nearest_exact2d.out 12749 dispatch: 12750 QuantizedCPU: _upsample_nearest_exact2d_quantized_cpu 12751 12752- func: upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12753 python_module: nn 12754 structured: True 12755 dispatch: 12756 CPU: upsample_nearest2d_backward_out_cpu 12757 CUDA: upsample_nearest2d_backward_out_cuda 12758 MPS: upsample_nearest2d_backward_out_mps 12759 12760- func: _upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12761 python_module: nn 12762 structured: True 12763 dispatch: 12764 CPU: _upsample_nearest_exact2d_backward_out_cpu 12765 CUDA: _upsample_nearest_exact2d_backward_out_cuda 12766 MPS: _upsample_nearest_exact2d_backward_out_mps 12767 12768- func: upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? 
scales_w=None) -> Tensor 12769 python_module: nn 12770 structured_delegate: upsample_nearest2d_backward.grad_input 12771 12772- func: _upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor 12773 python_module: nn 12774 structured_delegate: _upsample_nearest_exact2d_backward.grad_input 12775 12776- func: upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 12777 python_module: nn 12778 structured: True 12779 dispatch: 12780 CPU: upsample_nearest3d_out_cpu 12781 CUDA: upsample_nearest3d_out_cuda 12782 12783- func: _upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) 12784 python_module: nn 12785 structured: True 12786 dispatch: 12787 CPU: _upsample_nearest_exact3d_out_cpu 12788 CUDA: _upsample_nearest_exact3d_out_cuda 12789 12790- func: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor 12791 python_module: nn 12792 structured_delegate: upsample_nearest3d.out 12793 dispatch: 12794 QuantizedCPU: upsample_nearest3d_quantized_cpu 12795 12796- func: _upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor 12797 python_module: nn 12798 structured_delegate: _upsample_nearest_exact3d.out 12799 dispatch: 12800 QuantizedCPU: _upsample_nearest_exact3d_quantized_cpu 12801 12802- func: upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12803 python_module: nn 12804 structured: True 12805 dispatch: 12806 CPU: upsample_nearest3d_backward_out_cpu 12807 CUDA: upsample_nearest3d_backward_out_cuda 12808 12809- func: _upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) 12810 python_module: nn 12811 structured: True 12812 dispatch: 12813 CPU: _upsample_nearest_exact3d_backward_out_cpu 12814 CUDA: _upsample_nearest_exact3d_backward_out_cuda 12815 12816- func: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor 12817 python_module: nn 12818 structured_delegate: upsample_nearest3d_backward.grad_input 12819 12820- func: _upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor 12821 python_module: nn 12822 structured_delegate: _upsample_nearest_exact3d_backward.grad_input 12823 12824- func: sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) 
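# The sigmoid_backward / logit_backward / tanh_backward kernels declared here
# implement the textbook derivatives, taking the forward *output* (or, for
# logit, the original input `self`) rather than recomputing it
# (illustrative formulas, matching the schema arguments):
#
#   grad_input = grad_output * output * (1 - output)   # sigmoid_backward
#   grad_input = grad_output * (1 - output * output)   # tanh_backward
#   grad_input = grad_output / (self * (1 - self))     # logit_backward, eps=None case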
  python_module: nn
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: sigmoid_backward_out
    MPS: sigmoid_backward_out_mps
  tags: pointwise

- func: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor
  python_module: nn
  structured_delegate: sigmoid_backward.grad_input
  tags: pointwise

- func: logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!)
  python_module: nn
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: logit_backward_out
    MPS: logit_backward_out_mps
  tags: pointwise

- func: logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor
  python_module: nn
  structured_delegate: logit_backward.grad_input
  tags: pointwise

- func: tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)
  python_module: nn
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: tanh_backward_out
    MPS: tanh_backward_out_mps
  tags: pointwise

- func: tanh_backward(Tensor grad_output, Tensor output) -> Tensor
  python_module: nn
  structured_delegate: tanh_backward.grad_input
  tags: pointwise

# What's a thnn_conv_ versus a slow_conv_?
#
# Historically, we have inefficient implementations of convolutions
# coming from the THNN/THCUNN library. These convolutions typically
# operated by computing the Toeplitz matrix and then doing a matrix
# multiply with the input; this is very memory inefficient! However,
# occasionally, we really don't have anything better, so it's helpful
# to have these fallbacks when there is no more optimized implementation
# in cudnn or mkldnn, etc. Both thnn_ and slow_ convolutions fall
# into this bucket.
#
# The difference between these two designations is that thnn_ refers
# to a convolution that is still written in the "legacy" style; that is,
# C code in the THNN/ or THCUNN/ directory. A slow_ convolution is
# one that is written in the native style: modern C++. Algorithmically,
# these are the same thing, but we give them different prefixes to
# make the operational distinction clear. (A commented Python sketch of
# the im2col-plus-matmul approach described above appears a few entries
# below.)

- func: slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
  python_module: nn
  structured: True
  dispatch:
    CPU: slow_conv_transpose2d_structured_cpu
    CUDA: slow_conv_transpose2d_structured_cuda

- func: slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor
  python_module: nn
  structured_delegate: slow_conv_transpose2d.out

- func: slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
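# A rough illustration of the "compute the Toeplitz matrix and then do a matrix
# multiply" strategy described in the thnn_/slow_ comment above: a minimal
# Python sketch (only an illustration, not part of this schema file) that
# performs a 2D convolution as unfold (im2col) followed by one matmul and
# checks it against F.conv2d.
#
#   import torch
#   import torch.nn.functional as F
#
#   x = torch.randn(1, 3, 8, 8)   # (N, C_in, H, W)
#   w = torch.randn(4, 3, 3, 3)   # (C_out, C_in, kH, kW)
#
#   # im2col: gather 3x3 patches into columns -> (N, C_in*kH*kW, L)
#   cols = F.unfold(x, kernel_size=3)
#   # one big matrix multiply against the flattened weight -> (N, C_out, L)
#   out = w.view(4, -1) @ cols
#   # reshape the patch dimension back into the 6x6 output of a valid conv
#   out = out.view(1, 4, 6, 6)
#
#   assert torch.allclose(out, F.conv2d(x, w), atol=1e-4)
#
# The memory blow-up comes from materializing `cols`, which stores each input
# element once per patch that covers it; that is the inefficiency the comment
# above refers to.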
12896 python_module: nn 12897 dispatch: 12898 CPU: slow_conv_transpose3d_out_cpu 12899 CUDA: slow_conv_transpose3d_out_cuda 12900 12901- func: slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor 12902 python_module: nn 12903 dispatch: 12904 CPU: slow_conv_transpose3d_cpu 12905 CUDA: slow_conv_transpose3d_cuda 12906 12907- func: thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) 12908 python_module: nn 12909 12910- func: thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor 12911 python_module: nn 12912 12913- func: _slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!) 12914 python_module: nn 12915 dispatch: 12916 CPU: slow_conv2d_forward_out_cpu 12917 CUDA: slow_conv2d_forward_out_cuda 12918 12919- func: _slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor 12920 python_module: nn 12921 dispatch: 12922 CPU: slow_conv2d_forward_cpu 12923 CUDA: slow_conv2d_forward_cuda 12924 12925- func: _slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) 12926 python_module: nn 12927 dispatch: 12928 CPU: slow_conv2d_backward_out_cpu 12929 CUDA: slow_conv2d_backward_out_cuda 12930 12931- func: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) 12932 python_module: nn 12933 dispatch: 12934 CPU: slow_conv2d_backward_cpu 12935 CUDA: slow_conv2d_backward_cuda 12936 autogen: _slow_conv2d_backward.output_mask_out 12937 12938- func: _conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) 12939 use_const_ref_for_mutable_tensors: True 12940 python_module: nn 12941 dispatch: 12942 CUDA: conv_depthwise2d_cuda_out 12943 12944- func: _conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor 12945 python_module: nn 12946 dispatch: 12947 CUDA: conv_depthwise2d_cuda 12948 12949- func: conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor 12950 python_module: nn 12951 dispatch: 12952 CUDA: conv_depthwise3d_cuda 12953 autogen: conv_depthwise3d.out 12954 12955- func: slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) 12956 python_module: nn 12957 12958- func: slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor 12959 python_module: nn 12960 12961- func: slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? 
bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) 12962 python_module: nn 12963 dispatch: 12964 CPU: slow_conv3d_forward_out_cpu 12965 12966- func: slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor 12967 python_module: nn 12968 dispatch: 12969 CPU: slow_conv3d_forward_cpu 12970 12971- func: slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor 12972 python_module: nn 12973 dispatch: 12974 CPU: slow_conv_dilated2d_cpu 12975 CUDA: slow_conv_dilated2d_cuda 12976 autogen: slow_conv_dilated2d.out 12977 12978- func: slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor 12979 python_module: nn 12980 dispatch: 12981 CPU: slow_conv_dilated3d_cpu 12982 CUDA: slow_conv_dilated3d_cuda 12983 autogen: slow_conv_dilated3d.out 12984 12985- func: col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) 12986 python_module: nn 12987 dispatch: 12988 CPU: col2im_out_cpu 12989 CUDA: col2im_out_cuda 12990 12991- func: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor 12992 python_module: nn 12993 dispatch: 12994 CPU: col2im_cpu 12995 CUDA: col2im_cuda 12996 tags: core 12997 12998- func: column_stack(Tensor[] tensors) -> Tensor 12999 13000- func: column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) 13001 13002- func: im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) 13003 python_module: nn 13004 dispatch: 13005 CPU: im2col_out_cpu 13006 CUDA: im2col_out_cuda 13007 13008- func: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor 13009 python_module: nn 13010 dispatch: 13011 CPU: im2col_cpu 13012 CUDA: im2col_cuda 13013 13014- func: isfinite(Tensor self) -> Tensor 13015 variants: function, method 13016 device_check: NoCheck 13017 device_guard: False 13018 13019- func: isinf(Tensor self) -> Tensor 13020 variants: function, method 13021 device_check: NoCheck 13022 device_guard: False 13023 dispatch: 13024 CompositeExplicitAutograd: isinf 13025 SparseCPU, SparseCUDA: isinf_sparse 13026 SparseMeta: isinf_sparse_meta 13027 SparseCsrCPU, SparseCsrCUDA: isinf_sparse_csr 13028 autogen: isinf.out 13029 tags: [core, pointwise] 13030 13031- func: record_stream(Tensor(a!) self, Stream s) -> () 13032 variants: method 13033 dispatch: 13034 CUDA: record_stream_cuda 13035 13036- func: isposinf(Tensor self) -> Tensor 13037 variants: function, method 13038 structured_delegate: isposinf.out 13039 dispatch: 13040 SparseCPU, SparseCUDA: isposinf_sparse 13041 SparseCsrCPU, SparseCsrCUDA: isposinf_sparse_csr 13042 tags: pointwise 13043 13044- func: isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: isposinf_out
    SparseCPU, SparseCUDA: isposinf_sparse_out
    SparseCsrCPU, SparseCsrCUDA: isposinf_sparse_csr_out
  tags: pointwise

- func: isneginf(Tensor self) -> Tensor
  variants: function, method
  structured_delegate: isneginf.out
  dispatch:
    SparseCPU, SparseCUDA: isneginf_sparse
    SparseCsrCPU, SparseCsrCUDA: isneginf_sparse_csr
  tags: pointwise

- func: isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: isneginf_out
    SparseCPU, SparseCUDA: isneginf_sparse_out
    SparseCsrCPU, SparseCsrCUDA: isneginf_sparse_csr_out
  tags: pointwise

# NOTE [_add_batch_dim and _remove_batch_dim]
# _add_batch_dim and _remove_batch_dim are meant to be used in the implementation
# of the vmap frontend API (see torch/_vmap_internals.py). They are not
# user-facing, hence the leading underscore. Please don't use them anywhere else.
- func: _add_batch_dim(Tensor self, int batch_dim, int level) -> Tensor
  variants: function

# See NOTE [_add_batch_dim and _remove_batch_dim]
- func: _remove_batch_dim(Tensor self, int level, int batch_size, int out_dim) -> Tensor
  variants: function

## Functions related to the `torch.special` namespace
# Note [special namespace binding]
# Functions in the special python module should have their names start with
# the "special_" prefix and be bound to the desired Python name in
# torch/special/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/special.h.
# The "special_" names should be hidden from the user and not documented.
# (A commented example of this binding appears a few entries below.)

- func: special_entr(Tensor self) -> Tensor
  structured_delegate: special_entr.out
  python_module: special
  variants: function
  tags: pointwise

- func: special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  python_module: special
  variants: function
  dispatch:
    CPU, CUDA: special_entr_out
  tags: pointwise

- func: special_ndtri(Tensor self) -> Tensor
  structured_delegate: special_ndtri.out
  python_module: special
  variants: function
  tags: pointwise

- func: special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  python_module: special
  variants: function
  dispatch:
    CPU, CUDA: special_ndtri_out
  tags: pointwise

- func: special_log_ndtr(Tensor self) -> Tensor
  structured_delegate: special_log_ndtr.out
  python_module: special
  variants: function
  tags: pointwise

- func: special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  python_module: special
  variants: function
  dispatch:
    CPU, CUDA: special_log_ndtr_out
  tags: pointwise

- func: special_expm1(Tensor self) -> Tensor
  python_module: special
  variants: function

- func: special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
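# A minimal sketch of the binding convention described in
# Note [special namespace binding] above: the `special_entr` schema defined in
# this file surfaces for users as `torch.special.entr` (bound in
# torch/special/__init__.py), while the "special_"-prefixed name itself stays
# hidden. For example:
#
#   import torch
#
#   x = torch.tensor([0.25, 0.5, 1.0])
#   y = torch.special.entr(x)          # elementwise -x * log(x) for x > 0
#   assert torch.allclose(y, -x * torch.log(x))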
13138 python_module: special 13139 variants: function 13140 13141- func: special_exp2(Tensor self) -> Tensor 13142 python_module: special 13143 variants: function 13144 13145- func: special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13146 python_module: special 13147 variants: function 13148 13149- func: special_psi(Tensor self) -> Tensor 13150 python_module: special 13151 variants: function 13152 13153- func: special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13154 python_module: special 13155 variants: function 13156 13157- func: special_digamma(Tensor self) -> Tensor 13158 python_module: special 13159 variants: function 13160 13161- func: special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13162 python_module: special 13163 variants: function 13164 13165- func: special_gammaln(Tensor self) -> Tensor 13166 python_module: special 13167 variants: function 13168 13169- func: special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13170 python_module: special 13171 variants: function 13172 13173- func: special_erf(Tensor self) -> Tensor 13174 python_module: special 13175 variants: function 13176 13177- func: special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13178 python_module: special 13179 variants: function 13180 13181- func: special_erfc(Tensor self) -> Tensor 13182 python_module: special 13183 variants: function 13184 13185- func: special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13186 python_module: special 13187 13188- func: special_erfcx(Tensor self) -> Tensor 13189 python_module: special 13190 variants: function 13191 structured_delegate: special_erfcx.out 13192 tags: pointwise 13193 13194- func: special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13195 python_module: special 13196 structured: True 13197 structured_inherits: TensorIteratorBase 13198 dispatch: 13199 CPU, CUDA: special_erfcx_out 13200 tags: pointwise 13201 13202- func: special_erfinv(Tensor self) -> Tensor 13203 python_module: special 13204 variants: function 13205 13206- func: special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13207 python_module: special 13208 13209- func: special_ndtr(Tensor self) -> Tensor 13210 python_module: special 13211 variants: function 13212 13213- func: special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13214 python_module: special 13215 variants: function 13216 13217- func: special_xlog1py(Tensor self, Tensor other) -> Tensor 13218 device_check: NoCheck # TensorIterator 13219 python_module: special 13220 variants: function 13221 structured_delegate: special_xlog1py.out 13222 tags: pointwise 13223 13224- func: special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor 13225 device_check: NoCheck # TensorIterator 13226 python_module: special 13227 variants: function 13228 dispatch: 13229 CompositeExplicitAutograd: special_xlog1py 13230 tags: pointwise 13231 13232- func: special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor 13233 device_check: NoCheck # TensorIterator 13234 python_module: special 13235 variants: function 13236 dispatch: 13237 CompositeExplicitAutograd: special_xlog1py 13238 tags: pointwise 13239 13240- func: special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
13241 device_check: NoCheck # TensorIterator 13242 structured: True 13243 structured_inherits: TensorIteratorBase 13244 python_module: special 13245 variants: function 13246 dispatch: 13247 CPU, CUDA: special_xlog1py_out 13248 tags: pointwise 13249 13250- func: special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 13251 device_check: NoCheck # TensorIterator 13252 python_module: special 13253 variants: function 13254 dispatch: 13255 CompositeExplicitAutograd: special_xlog1py_out 13256 tags: pointwise 13257 13258- func: special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 13259 device_check: NoCheck # TensorIterator 13260 python_module: special 13261 variants: function 13262 dispatch: 13263 CompositeExplicitAutograd: special_xlog1py_out 13264 tags: pointwise 13265 13266- func: special_xlogy(Tensor self, Tensor other) -> Tensor 13267 device_check: NoCheck # TensorIterator 13268 python_module: special 13269 variants: function 13270 13271- func: special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor 13272 device_check: NoCheck # TensorIterator 13273 python_module: special 13274 variants: function 13275 13276- func: special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor 13277 device_check: NoCheck # TensorIterator 13278 python_module: special 13279 variants: function 13280 13281- func: special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 13282 device_check: NoCheck # TensorIterator 13283 python_module: special 13284 variants: function 13285 13286- func: special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 13287 device_check: NoCheck # TensorIterator 13288 python_module: special 13289 variants: function 13290 13291- func: special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 13292 device_check: NoCheck # TensorIterator 13293 python_module: special 13294 variants: function 13295 13296- func: special_zeta(Tensor self, Tensor other) -> Tensor 13297 device_check: NoCheck # TensorIterator 13298 python_module: special 13299 variants: function 13300 structured_delegate: special_zeta.out 13301 tags: pointwise 13302 13303- func: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor 13304 device_check: NoCheck # TensorIterator 13305 python_module: special 13306 variants: function 13307 dispatch: 13308 CompositeExplicitAutograd: special_zeta 13309 tags: pointwise 13310 13311- func: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor 13312 device_check: NoCheck # TensorIterator 13313 python_module: special 13314 variants: function 13315 dispatch: 13316 CompositeExplicitAutograd: special_zeta 13317 tags: pointwise 13318 13319- func: special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 13320 device_check: NoCheck # TensorIterator 13321 structured: True 13322 structured_inherits: TensorIteratorBase 13323 python_module: special 13324 variants: function 13325 dispatch: 13326 CPU, CUDA: special_zeta_out 13327 tags: pointwise 13328 13329- func: special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 13330 device_check: NoCheck # TensorIterator 13331 python_module: special 13332 variants: function 13333 dispatch: 13334 CompositeExplicitAutograd: special_zeta_out 13335 tags: pointwise 13336 13337- func: special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) 
13338 device_check: NoCheck # TensorIterator 13339 python_module: special 13340 variants: function 13341 dispatch: 13342 CompositeExplicitAutograd: special_zeta_out 13343 tags: pointwise 13344 13345- func: special_i0(Tensor self) -> Tensor 13346 python_module: special 13347 variants: function 13348 13349- func: special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13350 python_module: special 13351 variants: function 13352 13353- func: special_i0e(Tensor self) -> Tensor 13354 python_module: special 13355 variants: function 13356 structured_delegate: special_i0e.out 13357 tags: pointwise 13358 13359- func: special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13360 python_module: special 13361 structured: True 13362 structured_inherits: TensorIteratorBase 13363 dispatch: 13364 CPU, CUDA: special_i0e_out 13365 tags: pointwise 13366 13367- func: special_i1(Tensor self) -> Tensor 13368 python_module: special 13369 variants: function 13370 structured_delegate: special_i1.out 13371 tags: pointwise 13372 13373- func: special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13374 python_module: special 13375 structured: True 13376 structured_inherits: TensorIteratorBase 13377 dispatch: 13378 CPU, CUDA: special_i1_out 13379 tags: pointwise 13380 13381- func: special_i1e(Tensor self) -> Tensor 13382 python_module: special 13383 variants: function 13384 structured_delegate: special_i1e.out 13385 tags: pointwise 13386 13387- func: special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13388 python_module: special 13389 structured: True 13390 structured_inherits: TensorIteratorBase 13391 dispatch: 13392 CPU, CUDA: special_i1e_out 13393 tags: pointwise 13394 13395- func: special_logit(Tensor self, float? eps=None) -> Tensor 13396 python_module: special 13397 variants: function 13398 13399- func: special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) 13400 python_module: special 13401 13402- func: special_polygamma(int n, Tensor self) -> Tensor 13403 python_module: special 13404 variants: function 13405 13406- func: special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13407 python_module: special 13408 13409- func: special_logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor 13410 python_module: special 13411 variants: function 13412 13413- func: special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 13414 python_module: special 13415 13416- func: special_expit(Tensor self) -> Tensor 13417 python_module: special 13418 variants: function 13419 13420- func: special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13421 python_module: special 13422 variants: function 13423 13424- func: special_sinc(Tensor self) -> Tensor 13425 python_module: special 13426 variants: function 13427 13428- func: special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 13429 python_module: special 13430 variants: function 13431 13432- func: special_round(Tensor self, *, int decimals=0) -> Tensor 13433 python_module: special 13434 variants: function 13435 13436- func: special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!) 13437 python_module: special 13438 variants: function 13439 13440- func: special_log1p(Tensor self) -> Tensor 13441 python_module: special 13442 variants: function 13443 13444- func: special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
  python_module: special
  variants: function

- func: special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
  python_module: special
  variants: function

- func: special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  python_module: special
  variants: function

- func: special_gammainc(Tensor self, Tensor other) -> Tensor
  python_module: special
  variants: function

- func: special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  python_module: special
  variants: function

- func: special_gammaincc(Tensor self, Tensor other) -> Tensor
  python_module: special
  variants: function

- func: special_multigammaln(Tensor self, int p) -> Tensor
  python_module: special
  variants: function

- func: special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)
  python_module: special
  variants: function

- func: special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
  python_module: special
  variants: function

## Functions related to the fast Fourier transform and the torch.fft namespace
# Note [FFT namespace binding]
# Functions in the fft python module should have their names start with
# the "fft_" prefix and be bound to the desired Python name in
# torch/fft/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/fft.h.
# The "fft_" names should be hidden from the user and not documented.
#
# See fft_fft as an example; a commented usage sketch also appears a few
# entries below.

# torch.fft.fft
# NOTE: NOT an alias for torch.fft, which has different semantics
- func: fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
  python_module: fft
  variants: function
  dispatch:
    CompositeImplicitAutograd: fft_fft_symint

- func: fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
  python_module: fft
  variants: function
  dispatch:
    CompositeImplicitAutograd: fft_fft_symint_out

- func: fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
  python_module: fft
  variants: function
  dispatch:
    CompositeImplicitAutograd: fft_ifft_symint

- func: fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
  python_module: fft
  variants: function
  dispatch:
    CompositeImplicitAutograd: fft_ifft_symint_out

- func: fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
  python_module: fft
  variants: function
  dispatch:
    CompositeImplicitAutograd: fft_rfft_symint

- func: fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
  python_module: fft
  variants: function
  dispatch:
    CompositeImplicitAutograd: fft_rfft_symint_out

- func: fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
  python_module: fft
  variants: function
  dispatch:
    CompositeImplicitAutograd: fft_irfft_symint

- func: fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
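# A minimal usage sketch for the fft_ entries above (they surface as
# torch.fft.fft / torch.fft.ifft etc., taking the same n/dim/norm arguments
# declared in the schemas): with norm="ortho" on both sides the transform is
# unitary, so ifft(fft(x)) recovers x.
#
#   import torch
#
#   x = torch.randn(4, 8)
#   X = torch.fft.fft(x, n=8, dim=-1, norm="ortho")
#   x_back = torch.fft.ifft(X, n=8, dim=-1, norm="ortho")
#   assert torch.allclose(x_back.real, x, atol=1e-5)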
13534 python_module: fft 13535 variants: function 13536 dispatch: 13537 CompositeImplicitAutograd: fft_irfft_symint_out 13538 13539- func: fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor 13540 python_module: fft 13541 variants: function 13542 dispatch: 13543 CompositeImplicitAutograd: fft_hfft_symint 13544 13545- func: fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13546 python_module: fft 13547 variants: function 13548 dispatch: 13549 CompositeImplicitAutograd: fft_hfft_symint_out 13550 13551- func: fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor 13552 python_module: fft 13553 variants: function 13554 dispatch: 13555 CompositeImplicitAutograd: fft_ihfft_symint 13556 13557- func: fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13558 python_module: fft 13559 variants: function 13560 dispatch: 13561 CompositeImplicitAutograd: fft_ihfft_symint_out 13562 13563- func: fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor 13564 python_module: fft 13565 variants: function 13566 dispatch: 13567 CompositeImplicitAutograd: fft_fft2_symint 13568 13569- func: fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13570 python_module: fft 13571 variants: function 13572 dispatch: 13573 CompositeImplicitAutograd: fft_fft2_symint_out 13574 13575- func: fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor 13576 python_module: fft 13577 variants: function 13578 dispatch: 13579 CompositeImplicitAutograd: fft_ifft2_symint 13580 13581- func: fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13582 python_module: fft 13583 variants: function 13584 dispatch: 13585 CompositeImplicitAutograd: fft_ifft2_symint_out 13586 13587- func: fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor 13588 python_module: fft 13589 variants: function 13590 dispatch: 13591 CompositeImplicitAutograd: fft_rfft2_symint 13592 13593- func: fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13594 python_module: fft 13595 variants: function 13596 dispatch: 13597 CompositeImplicitAutograd: fft_rfft2_symint_out 13598 13599- func: fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor 13600 python_module: fft 13601 variants: function 13602 dispatch: 13603 CompositeImplicitAutograd: fft_irfft2_symint 13604 13605- func: fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13606 python_module: fft 13607 variants: function 13608 dispatch: 13609 CompositeImplicitAutograd: fft_irfft2_symint_out 13610 13611- func: fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor 13612 use_const_ref_for_mutable_tensors: True 13613 python_module: fft 13614 variants: function 13615 dispatch: 13616 CompositeImplicitAutograd: fft_hfft2_symint 13617 13618- func: fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 
13619 use_const_ref_for_mutable_tensors: True 13620 python_module: fft 13621 variants: function 13622 dispatch: 13623 CompositeImplicitAutograd: fft_hfft2_symint_out 13624 13625- func: fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor 13626 use_const_ref_for_mutable_tensors: True 13627 python_module: fft 13628 variants: function 13629 dispatch: 13630 CompositeImplicitAutograd: fft_ihfft2_symint 13631 13632- func: fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13633 use_const_ref_for_mutable_tensors: True 13634 python_module: fft 13635 variants: function 13636 dispatch: 13637 CompositeImplicitAutograd: fft_ihfft2_symint_out 13638 13639- func: fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor 13640 python_module: fft 13641 variants: function 13642 dispatch: 13643 CompositeImplicitAutograd: fft_fftn_symint 13644 13645- func: fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13646 python_module: fft 13647 variants: function 13648 dispatch: 13649 CompositeImplicitAutograd: fft_fftn_symint_out 13650 13651- func: fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor 13652 python_module: fft 13653 variants: function 13654 dispatch: 13655 CompositeImplicitAutograd: fft_ifftn_symint 13656 13657- func: fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13658 python_module: fft 13659 variants: function 13660 dispatch: 13661 CompositeImplicitAutograd: fft_ifftn_symint_out 13662 13663- func: fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor 13664 python_module: fft 13665 variants: function 13666 dispatch: 13667 CompositeImplicitAutograd: fft_rfftn_symint 13668 13669- func: fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13670 python_module: fft 13671 variants: function 13672 dispatch: 13673 CompositeImplicitAutograd: fft_rfftn_symint_out 13674 13675- func: fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor 13676 python_module: fft 13677 variants: function 13678 dispatch: 13679 CompositeImplicitAutograd: fft_irfftn_symint 13680 13681- func: fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13682 python_module: fft 13683 variants: function 13684 dispatch: 13685 CompositeImplicitAutograd: fft_irfftn_symint_out 13686 13687- func: fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor 13688 use_const_ref_for_mutable_tensors: True 13689 python_module: fft 13690 variants: function 13691 dispatch: 13692 CompositeImplicitAutograd: fft_hfftn_symint 13693 13694- func: fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) 13695 use_const_ref_for_mutable_tensors: True 13696 python_module: fft 13697 variants: function 13698 dispatch: 13699 CompositeImplicitAutograd: fft_hfftn_symint_out 13700 13701- func: fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor 13702 use_const_ref_for_mutable_tensors: True 13703 python_module: fft 13704 variants: function 13705 dispatch: 13706 CompositeImplicitAutograd: fft_ihfftn_symint 13707 13708- func: fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? 
norm=None, *, Tensor(a!) out) -> Tensor(a!) 13709 use_const_ref_for_mutable_tensors: True 13710 python_module: fft 13711 variants: function 13712 dispatch: 13713 CompositeImplicitAutograd: fft_ihfftn_symint_out 13714 13715- func: fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 13716 python_module: fft 13717 variants: function 13718 dispatch: 13719 CompositeExplicitAutograd: fft_fftfreq 13720 13721- func: fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) 13722 python_module: fft 13723 variants: function 13724 dispatch: 13725 CompositeExplicitAutograd: fft_fftfreq_out 13726 13727- func: fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor 13728 python_module: fft 13729 variants: function 13730 dispatch: 13731 CompositeExplicitAutograd: fft_rfftfreq 13732 13733- func: fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) 13734 python_module: fft 13735 variants: function 13736 dispatch: 13737 CompositeExplicitAutograd: fft_rfftfreq_out 13738 13739- func: fft_fftshift(Tensor self, int[1]? dim=None) -> Tensor 13740 python_module: fft 13741 variants: function 13742 13743- func: fft_ifftshift(Tensor self, int[1]? dim=None) -> Tensor 13744 python_module: fft 13745 variants: function 13746 13747## Functions for linear algebra and the torch.linalg namespace 13748# Note [linalg namespace binding] 13749# Functions in the linalg python module should have their names start with 13750# "linalg_" and be bound to the desired Python name in 13751# torch/linalg/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/linalg.h. 13752# The "linalg_" names should be hidden from the user and not documented. 13753# 13754# See linalg_det as an example. 13755 13756# "_ex" stands for experimental 13757- func: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) 13758 python_module: linalg 13759 structured_delegate: linalg_cholesky_ex.L 13760 13761- func: linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info) 13762 python_module: linalg 13763 structured: True 13764 dispatch: 13765 CPU, CUDA: linalg_cholesky_ex_out 13766 13767- func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor 13768 python_module: linalg 13769 13770- func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!) 13771 python_module: linalg 13772 13773- func: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor 13774 python_module: linalg 13775 variants: function 13776 structured_delegate: linalg_cross.out 13777 dispatch: 13778 ZeroTensor: linalg_cross_zerotensor 13779 13780- func: linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) 13781 python_module: linalg 13782 structured: True 13783 dispatch: 13784 CPU, CUDA, MPS: linalg_cross_out 13785 13786# linalg.lu_factor 13787- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) 13788 python_module: linalg 13789 variants: function 13790 13791- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) 
pivots) 13792 python_module: linalg 13793 variants: function 13794 13795- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) 13796 python_module: linalg 13797 structured_delegate: linalg_lu_factor_ex.out 13798 variants: function 13799 13800- func: linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) 13801 python_module: linalg 13802 variants: function 13803 structured: True 13804 dispatch: 13805 CPU, CUDA: linalg_lu_factor_ex_out 13806 13807# linalg.lu 13808- func: linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U) 13809 python_module: linalg 13810 structured_delegate: linalg_lu.out 13811 variants: function 13812 13813- func: linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) 13814 python_module: linalg 13815 variants: function 13816 structured: True 13817 dispatch: 13818 CPU, CUDA: linalg_lu_out 13819 13820# linalg.lu_solve 13821- func: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor 13822 python_module: linalg 13823 structured_delegate: linalg_lu_solve.out 13824 variants: function 13825 13826- func: linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!) 13827 python_module: linalg 13828 variants: function 13829 structured: True 13830 dispatch: 13831 CPU, CUDA: linalg_lu_solve_out 13832 13833# linalg.det 13834- func: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) 13835 structured_delegate: _linalg_det.result 13836 13837- func: _linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) 13838 structured: True 13839 dispatch: 13840 CPU, CUDA: _linalg_det_out 13841 13842- func: linalg_det(Tensor A) -> Tensor 13843 python_module: linalg 13844 variants: function 13845 13846- func: linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) 13847 python_module: linalg 13848 13849# torch.det, alias for torch.linalg.det 13850- func: det(Tensor self) -> Tensor 13851 variants: function, method 13852 13853- func: linalg_ldl_factor_ex(Tensor self, *, bool hermitian=False, bool check_errors=False) -> (Tensor LD, Tensor pivots, Tensor info) 13854 structured_delegate: linalg_ldl_factor_ex.out 13855 python_module: linalg 13856 variants: function 13857 13858- func: linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) 13859 structured: True 13860 python_module: linalg 13861 variants: function 13862 dispatch: 13863 CPU, CUDA: linalg_ldl_factor_ex_out 13864 13865- func: linalg_ldl_factor(Tensor self, *, bool hermitian=False) -> (Tensor LD, Tensor pivots) 13866 python_module: linalg 13867 variants: function 13868 13869- func: linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) 
pivots) 13870 python_module: linalg 13871 variants: function 13872 13873- func: linalg_ldl_solve(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False) -> Tensor 13874 structured_delegate: linalg_ldl_solve.out 13875 python_module: linalg 13876 variants: function 13877 13878- func: linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) 13879 structured: True 13880 python_module: linalg 13881 variants: function 13882 dispatch: 13883 CPU, CUDA: linalg_ldl_solve_out 13884 13885- func: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values) 13886 python_module: linalg 13887 variants: function 13888 dispatch: 13889 CompositeExplicitAutograd: linalg_lstsq 13890 tags: dynamic_output_shape 13891 13892- func: linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) 13893 python_module: linalg 13894 variants: function 13895 dispatch: 13896 CPU, CUDA: linalg_lstsq_out 13897 tags: dynamic_output_shape 13898 13899# torch.linalg.matmul, alias for torch.matmul 13900- func: linalg_matmul(Tensor self, Tensor other) -> Tensor 13901 python_module: linalg 13902 variants: function 13903 13904- func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 13905 python_module: linalg 13906 13907- func: linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor 13908 python_module: linalg 13909 variants: function 13910 13911- func: linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) 13912 python_module: linalg 13913 13914- func: linalg_matrix_exp(Tensor self) -> Tensor 13915 python_module: linalg 13916 variants: function 13917 dispatch: 13918 CPU, CUDA: linalg_matrix_exp 13919 autogen: linalg_matrix_exp.out 13920 13921- func: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) 13922 structured_delegate: _linalg_slogdet.sign 13923 13924- func: _linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) 13925 structured: True 13926 dispatch: 13927 CPU, CUDA: _linalg_slogdet_out 13928 13929- func: linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet) 13930 python_module: linalg 13931 13932- func: linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) 13933 python_module: linalg 13934 13935- func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) 13936 variants: function, method 13937 13938- func: slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) 13939 variants: function 13940 13941- func: logdet(Tensor self) -> Tensor 13942 variants: function, method 13943 13944- func: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) 13945 python_module: linalg 13946 variants: function 13947 dispatch: 13948 CPU, CUDA: linalg_eig 13949 13950- func: linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) 
eigenvectors)
  python_module: linalg
  dispatch:
    CPU, CUDA: linalg_eig_out

- func: _linalg_eigvals(Tensor self) -> Tensor
  python_module: linalg
  dispatch:
    CPU, CUDA: _linalg_eigvals

- func: linalg_eigvals(Tensor self) -> Tensor
  python_module: linalg

- func: linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  dispatch:
    CPU, CUDA: linalg_eigvals_out

# This function exposes the `compute_v` flag, which is then used to implement `linalg.eigh` and
# `linalg.eigvalsh` as composite functions that call this one.
- func: _linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors)
  structured_delegate: _linalg_eigh.eigenvalues

- func: _linalg_eigh.eigenvalues(Tensor A, str UPLO="L", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
  structured: True
  dispatch:
    CPU, CUDA: _linalg_eigh_out

- func: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
  python_module: linalg

- func: linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
  python_module: linalg

- func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor
  python_module: linalg

- func: linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg

- func: linalg_householder_product(Tensor input, Tensor tau) -> Tensor
  python_module: linalg
  variants: function
  dispatch:
    CPU, CUDA: linalg_householder_product

- func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  dispatch:
    CPU, CUDA: linalg_householder_product_out

- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
  python_module: linalg
  structured_delegate: linalg_inv_ex.inverse

- func: linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info)
  python_module: linalg
  structured: True
  dispatch:
    CPU, CUDA: linalg_inv_ex_out
    MPS: linalg_inv_ex_out_mps

- func: linalg_inv(Tensor A) -> Tensor
  python_module: linalg

- func: linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg

- func: inverse(Tensor self) -> Tensor
  variants: function, method

- func: inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)

- func: inner(Tensor self, Tensor other) -> Tensor
  variants: function, method

- func: inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

- func: outer(Tensor self, Tensor vec2) -> Tensor
  variants: function, method

- func: outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)

# torch.ger, alias for torch.outer
- func: ger(Tensor self, Tensor vec2) -> Tensor
  variants: function, method

- func: ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)

- func: linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
  python_module: linalg
  variants: function

- func: linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
  python_module: linalg
  variants: function

- func: linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  variants: function

- func: linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  variants: function

- func: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
  python_module: linalg
  variants: function
  structured_delegate: linalg_vector_norm.out

- func: linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  structured: True
  dispatch:
    CPU, CUDA: linalg_vector_norm_out
    MPS: linalg_vector_norm_out_mps

- func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
  python_module: linalg

- func: linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg

- func: linalg_matrix_norm.str_ord(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
  python_module: linalg

- func: linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg

# This function exposes the `compute_uv` flag, which is then used to implement `linalg.svd` and
# `linalg.svdvals` as composite functions that call this one. (A commented
# sketch of this composite pattern appears a few entries below.)
- func: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
  variants: function
  structured_delegate: _linalg_svd.U

- func: _linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
  structured: True
  dispatch:
    CPU, CUDA: _linalg_svd_out

- func: linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
  python_module: linalg
  variants: function

- func: linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
  python_module: linalg
  variants: function

- func: linalg_svdvals(Tensor A, *, str? driver=None) -> Tensor
  python_module: linalg
  variants: function

- func: linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  variants: function

- func: linalg_cond(Tensor self, Scalar? p=None) -> Tensor
  python_module: linalg
  variants: function

- func: linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!)
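# A minimal sketch of the composite pattern described in the _linalg_svd and
# _linalg_eigh comments above, expressed through the public Python API (the
# internal compute_uv/compute_v operators themselves are not user-facing):
#
#   import torch
#
#   A = torch.randn(5, 3)
#   # linalg.svd and linalg.svdvals are built on the same internal
#   # factorization, so the singular values they return agree.
#   U, S, Vh = torch.linalg.svd(A, full_matrices=False)
#   assert torch.allclose(S, torch.linalg.svdvals(A), atol=1e-5)
#
#   # Same pattern for the Hermitian eigendecomposition.
#   sym = A.T @ A
#   w, V = torch.linalg.eigh(sym)
#   assert torch.allclose(w, torch.linalg.eigvalsh(sym), atol=1e-5)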
14111 python_module: linalg 14112 variants: function 14113 14114- func: linalg_cond.p_str(Tensor self, str p) -> Tensor 14115 python_module: linalg 14116 variants: function 14117 14118- func: linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!) 14119 python_module: linalg 14120 variants: function 14121 14122- func: linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor 14123 python_module: linalg 14124 variants: function 14125 dispatch: 14126 # calls svd, which calls mH() (view op) 14127 # also calls narrow() 14128 CompositeExplicitAutogradNonFunctional: linalg_pinv 14129 14130- func: linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) 14131 python_module: linalg 14132 variants: function 14133 dispatch: 14134 CompositeExplicitAutograd: linalg_pinv_out 14135 14136- func: linalg_pinv.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor 14137 cpp_no_default_args: ['atol', 'rtol'] 14138 python_module: linalg 14139 variants: function 14140 14141- func: linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) 14142 cpp_no_default_args: ['atol', 'rtol'] 14143 python_module: linalg 14144 variants: function 14145 14146- func: linalg_pinv(Tensor self, float rcond, bool hermitian=False) -> Tensor 14147 python_module: linalg 14148 variants: function 14149 14150- func: linalg_pinv.rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False) -> Tensor 14151 python_module: linalg 14152 variants: function 14153 14154- func: linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) 14155 python_module: linalg 14156 variants: function 14157 14158- func: linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) 14159 python_module: linalg 14160 variants: function 14161 14162- func: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) 14163 structured_delegate: _linalg_solve_ex.result 14164 14165- func: _linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) 14166 structured: True 14167 dispatch: 14168 CPU, CUDA: _linalg_solve_ex_out 14169 14170- func: linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info) 14171 python_module: linalg 14172 14173- func: linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info) 14174 python_module: linalg 14175 14176- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor 14177 python_module: linalg 14178 14179- func: linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!) 14180 python_module: linalg 14181 14182- func: linalg_tensorinv(Tensor self, int ind=2) -> Tensor 14183 python_module: linalg 14184 variants: function 14185 14186- func: linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!) 14187 python_module: linalg 14188 variants: function 14189 14190- func: linalg_tensorsolve(Tensor self, Tensor other, int[]? 
dims=None) -> Tensor
  python_module: linalg
  variants: function

- func: linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  variants: function

- func: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
  python_module: linalg
  variants: function
  structured_delegate: linalg_qr.out

- func: linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)
  python_module: linalg
  structured: True
  dispatch:
    CPU, CUDA: linalg_qr_out

- func: linalg_matrix_power(Tensor self, int n) -> Tensor
  python_module: linalg

- func: linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg

- func: linalg_matrix_rank.atol_rtol_tensor(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor
  python_module: linalg
  variants: function

- func: linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  variants: function

- func: linalg_matrix_rank.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor
  cpp_no_default_args: ['atol', 'rtol']
  python_module: linalg
  variants: function

- func: linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
  cpp_no_default_args: ['atol', 'rtol']
  python_module: linalg
  variants: function

- func: linalg_matrix_rank(Tensor self, float tol, bool hermitian=False) -> Tensor
  python_module: linalg
  variants: function

- func: linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  variants: function

- func: linalg_matrix_rank.tol_tensor(Tensor input, Tensor tol, bool hermitian=False) -> Tensor
  python_module: linalg
  variants: function

- func: linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  variants: function

- func: linalg_multi_dot(Tensor[] tensors) -> Tensor
  python_module: linalg

- func: linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg

## Functions related to the `torch.nested` namespace
# Note [nested namespace binding]
# Functions in the nested python module should have their names start with
# the "nested_" prefix and be bound to the desired Python name in
# torch/nested/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/nested.h.
# The "nested_" names should be hidden from the user and not documented.

- func: nested_to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor
  python_module: nested
  variants: function

## Functions that are only for testing
# They are undocumented and should not be used outside of tests.
## Functions that are only for testing
# They are undocumented and should not be used outside of tests.
- func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor

# Note: for testing COW materialization within `at::parallel_for` loop function
- func: _test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutograd: _test_parallel_materialize

# Note: this function is only for testing.
- func: _test_optional_intlist(Tensor values, int[]? addends) -> Tensor
  python_module: nn
  dispatch:
    CPU: _test_optional_intlist
  autogen: _test_optional_intlist.out

# Note: this function is only for testing.
- func: _test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor
  python_module: nn
  dispatch:
    CPU: _test_optional_intlist
  autogen: _test_optional_filled_intlist.out

# Note: this function is only for testing.
- func: _test_optional_floatlist(Tensor values, float[]? addends) -> Tensor
  python_module: nn
  dispatch:
    CPU: _test_optional_floatlist
  autogen: _test_optional_floatlist.out

# Note: this function is only for testing.
- func: _test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor
  python_module: nn

# Note: this function is only for testing.
- func: _test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor
  python_module: nn

# Note: this function is only for testing.
- func: _test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2") -> Tensor
  cpp_no_default_args: ['a', 'b']
  python_module: nn

# Note: this function is only for testing.
- func: _test_warn_in_autograd(Tensor self) -> Tensor
  python_module: nn
  dispatch:
    CompositeExplicitAutograd: _test_warn_in_autograd
  autogen: _test_warn_in_autograd.out

# Note: this function is only for testing.
- func: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor
  dispatch:
    # the NestedTensor keys are necessary because NestedTensor has been removed
    # from the CompositeExplicitAutograd keyset; see Note [NestedTensor Not Included in Backend Keys]
    CompositeExplicitAutograd, NestedTensorCPU, NestedTensorCUDA: _test_autograd_multiple_dispatch_fullcoverage
  autogen: _test_autograd_multiple_dispatch.fullcoverage_out

# Note: this function is only for testing.
- func: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor
  dispatch:
    CompositeImplicitAutograd, NestedTensorCPU, NestedTensorCUDA: _test_autograd_multiple_dispatch_ntonly

# Note: this function is only for testing.
- func: _test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a)
  dispatch:
    CompositeExplicitAutograd: _test_autograd_multiple_dispatch_view

# Note: this function is only for testing.
- func: _test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: _test_autograd_multiple_dispatch_view_copy
  tags: view_copy
  autogen: _test_autograd_multiple_dispatch_view_copy.out

- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
  variants: function
  dispatch:
    CPU, CUDA: segment_reduce_kernel
  autogen: segment_reduce.out

- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor
  variants: function
  dispatch:
    CPU, CUDA: _segment_reduce_backward_kernel
  autogen: _segment_reduce_backward.out

- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0) -> Tensor
  python_module: nn
  variants: function

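# pad_sequence is surfaced in Python through torch.nn.utils.rnn.pad_sequence,
# whose wrapper ultimately calls into this op (a sketch of typical usage):
#
#   a = torch.tensor([1, 2, 3])
#   b = torch.tensor([4, 5])
#   torch.nn.utils.rnn.pad_sequence([a, b], batch_first=True, padding_value=0.0)
#   # tensor([[1, 2, 3],
#   #         [4, 5, 0]])
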
- func: flatten_dense_tensors(Tensor[] tensors) -> Tensor
  variants: function
  python_module: nn

- func: unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[]
  variants: function
  python_module: nn

- func: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutograd: _nested_tensor_from_tensor_list
  autogen: _nested_tensor_from_tensor_list.out

- func: _fw_primal_copy(Tensor self, int level) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: _fw_primal_copy
  tags: view_copy
  autogen: _fw_primal_copy.out

- func: _make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: _make_dual_copy
  tags: view_copy
  autogen: _make_dual_copy.out

- func: view_as_real_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: view_as_real_copy
  tags: view_copy
  autogen: view_as_real_copy.out

- func: view_as_complex_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: view_as_complex_copy
  tags: view_copy
  autogen: view_as_complex_copy.out

- func: _conj_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: _conj_copy
  tags: view_copy
  autogen: _conj_copy.out

- func: _neg_view_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: _neg_view_copy
  tags: view_copy
  autogen: _neg_view_copy.out

- func: as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: as_strided_copy_symint
  tags: view_copy
  autogen: as_strided_copy.out

- func: _sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: _sparse_broadcast_to_copy
  tags: view_copy
  autogen: _sparse_broadcast_to_copy.out

- func: diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: diagonal_copy
  tags: view_copy
  autogen: diagonal_copy.out

- func: expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: expand_copy_symint
  tags: view_copy
  autogen: expand_copy.out

- func: permute_copy(Tensor self, int[] dims) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: permute_copy
  tags: view_copy
  autogen: permute_copy.out

- func: _reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: _reshape_alias_copy_symint
  tags: view_copy
  autogen: _reshape_alias_copy.out

- func: select_copy.int(Tensor self, int dim, SymInt index) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: select_copy_symint
    SparseCsrCPU, SparseCsrCUDA: select_copy_sparse_csr
  tags: view_copy
  autogen: select_copy.int_out

- func: detach_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: detach_copy
  tags: view_copy
  autogen: detach_copy.out

- func: slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: slice_copy_Tensor_symint
  tags: view_copy
  autogen: slice_copy.Tensor_out

- func: split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: split_copy_Tensor_symint
  tags: view_copy

- func: split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: split_with_sizes_copy_symint
  tags: view_copy

- func: squeeze_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: squeeze_copy
  tags: view_copy
  autogen: squeeze_copy.out

- func: squeeze_copy.dim(Tensor self, int dim) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: squeeze_copy_dim
  tags: view_copy
  autogen: squeeze_copy.dim_out

- func: squeeze_copy.dims(Tensor self, int[] dim) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: squeeze_copy_dims
  tags: view_copy
  autogen: squeeze_copy.dims_out

- func: t_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: t_copy
  tags: view_copy
  autogen: t_copy.out

- func: transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: transpose_copy_int
  tags: view_copy
  autogen: transpose_copy.int_out

- func: unsqueeze_copy(Tensor self, int dim) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: unsqueeze_copy
  tags: view_copy
  autogen: unsqueeze_copy.out

- func: _indices_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: _indices_copy
  tags: view_copy
  autogen: _indices_copy.out

- func: _values_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: _values_copy
  tags: view_copy
  autogen: _values_copy.out

- func: indices_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: indices_copy
  tags: view_copy
  autogen: indices_copy.out

- func: values_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: values_copy
  tags: view_copy
  autogen: values_copy.out

- func: crow_indices_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: crow_indices_copy
  tags: view_copy
  autogen: crow_indices_copy.out

- func: col_indices_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: col_indices_copy
  tags: view_copy
  autogen: col_indices_copy.out

- func: ccol_indices_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: ccol_indices_copy
  tags: view_copy
  autogen: ccol_indices_copy.out

- func: row_indices_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: row_indices_copy
  tags: view_copy
  autogen: row_indices_copy.out

- func: unbind_copy.int(Tensor self, int dim=0) -> Tensor[]
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: unbind_copy_int
  tags: view_copy

- func: unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> ()
  variants: function
  dispatch:
    CompositeExplicitAutograd: unbind_copy_int_out

- func: split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
  variants: function
  dispatch:
    CompositeExplicitAutograd: split_copy_Tensor_out

- func: split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
  variants: function
  dispatch:
    CompositeExplicitAutograd: split_with_sizes_copy_out
    CUDA: split_with_sizes_copy_out_cuda

- func: view_copy(Tensor self, SymInt[] size) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: view_copy_symint
  tags: view_copy
  autogen: view_copy.out

- func: view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: view_copy_dtype
  tags: view_copy
  autogen: view_copy.dtype_out

- func: unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: unfold_copy
  tags: view_copy
  autogen: unfold_copy.out

- func: alias_copy(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutogradNonFunctional: alias_copy
  tags: view_copy
  autogen: alias_copy.out

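# The *_copy ops above are the non-aliasing counterparts of the corresponding
# view ops: same values, but a freshly allocated tensor rather than a view of
# the input (hence `tags: view_copy` and the
# CompositeExplicitAutogradNonFunctional kernels). An illustrative sketch:
#
#   x = torch.arange(6).reshape(2, 3)
#   v = x.transpose(0, 1)                        # aliases x's storage
#   c = torch.ops.aten.transpose_copy(x, 0, 1)   # independent storage
#   x.fill_(0)
#   # v now reflects the fill_; c is unchanged.
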
- func: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
  variants: method
  dispatch:
    NestedTensorCPU: NestedTensor_to_padded_tensor_generic
    NestedTensorCUDA: NestedTensor_to_padded_tensor_cuda
  autogen: to_padded_tensor.out

- func: _jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor
  variants: function
  dispatch:
    CUDA: _fbgemm_jagged_to_padded_dense_forward

- func: _padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor
  variants: function
  dispatch:
    CUDA: _fbgemm_dense_to_jagged_forward_symint

- func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor
  dispatch:
    NestedTensorCPU: NestedTensor_softmax_dropout
    NestedTensorCUDA: NestedTensor_softmax_dropout_cuda
  tags: nondeterministic_seeded

# Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is.
- func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor
  variants: function
  dispatch:
    CPU, CUDA, NestedTensorCPU, NestedTensorCUDA: transformer_encoder_layer_forward
  autogen: _transformer_encoder_layer_fwd.out

- func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor)
  variants: function
  dispatch:
    CPU, NestedTensorCPU: native_multi_head_attention_cpu
    CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
  autogen: _native_multi_head_attention.out

- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
  python_module: nn
  variants: function
  autogen: scaled_dot_product_attention.out
  tags: nondeterministic_seeded

# This aten function is kept so that we can test the choice function from Python
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> int
  dispatch:
    Meta: _fused_sdp_choice_meta
    CPU, NestedTensorCPU: _fused_sdp_choice_cpp
    CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
  tags: nondeterministic_seeded

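# scaled_dot_product_attention is exposed as
# torch.nn.functional.scaled_dot_product_attention, and _fused_sdp_choice
# reports, as an integer backend id, which fused kernel would be selected for
# the given inputs (that is what the Python-side tests inspect). A usage sketch:
#
#   q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
#   k, v = torch.randn_like(q), torch.randn_like(q)
#   out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
#   backend_id = torch.ops.aten._fused_sdp_choice(q, k, v, is_causal=True)
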
- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
  variants: function
  tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  dispatch:
    CUDA: _scaled_dot_product_flash_attention_cuda
    NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
  tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)
  dispatch:
    CPU: _scaled_dot_product_flash_attention_cpu
  tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
  device_check: NoCheck
  variants: function
  dispatch:
    CUDA: _scaled_dot_product_flash_attention_backward_cuda
    NestedTensorCUDA: _scaled_dot_product_flash_attention_backward_nested

- func: _scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
  device_check: NoCheck
  variants: function
  dispatch:
    CPU: _scaled_dot_product_flash_attention_cpu_backward

- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
  dispatch:
    CUDA: _scaled_dot_product_efficient_attention_cuda
    NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda
  tags: nondeterministic_seeded

- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)
  device_check: NoCheck
  dispatch:
    CUDA: _scaled_dot_product_efficient_attention_backward_cuda
  tags: nondeterministic_seeded

- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  dispatch:
    CUDA: _scaled_dot_product_cudnn_attention_cuda
  tags: nondeterministic_seeded

- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
  dispatch:
    CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
  tags: nondeterministic_seeded

- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  variants: function
  dispatch:
    CUDA: _flash_attention_forward
  tags: nondeterministic_seeded

- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
  device_check: NoCheck
  variants: function
  dispatch:
    CUDA: _flash_attention_backward

# Returns output, logsumexp if compute_logsumexp
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
  variants: function
  dispatch:
    CUDA: _efficient_attention_forward
  tags: nondeterministic_seeded

- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor)
  device_check: NoCheck
  variants: function
  dispatch:
    CUDA: _efficient_attention_backward

- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
  variants: function
  dispatch:
    CUDA: triton_scaled_dot_attention
  tags: nondeterministic_seeded
  autogen: _triton_scaled_dot_attention.out

- func: _fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!)
  variants: function
  dispatch:
    CUDA: _fill_mem_eff_dropout_mask_
  tags: nondeterministic_seeded

- func: _triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor
  variants: function
  dispatch:
    CUDA: triton_multi_head_attention
  autogen: _triton_multi_head_attention.out

- func: special_airy_ai(Tensor x) -> Tensor
  python_module: special
  structured_delegate: special_airy_ai.out
  variants: function
  tags: pointwise

- func: special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_airy_ai_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_bessel_j0(Tensor self) -> Tensor
  python_module: special
  structured_delegate: special_bessel_j0.out
  variants: function
  tags: pointwise

- func: special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_bessel_j0_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_bessel_j1(Tensor self) -> Tensor
  python_module: special
  structured_delegate: special_bessel_j1.out
  variants: function
  tags: pointwise

- func: special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_bessel_j1_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_bessel_y0(Tensor self) -> Tensor
  python_module: special
  structured_delegate: special_bessel_y0.out
  variants: function
  tags: pointwise

- func: special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_bessel_y0_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_bessel_y1(Tensor self) -> Tensor
  python_module: special
  structured_delegate: special_bessel_y1.out
  variants: function
  tags: pointwise

- func: special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_bessel_y1_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_chebyshev_polynomial_t.out
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_t
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_t
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_chebyshev_polynomial_t_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_t_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_t_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

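# For the orthogonal-polynomial ops above and below, the base overload takes two
# tensors, while the x_scalar / n_scalar overloads cover the case where one
# argument is a plain Python number (presumably by wrapping the scalar and
# reusing the tensor kernel, hence the CompositeExplicitAutograd dispatch).
# A usage sketch, assuming the torch.special binding drops the "special_" prefix:
#
#   x = torch.linspace(-1, 1, 5)
#   torch.special.chebyshev_polynomial_t(x, 3)                 # n_scalar overload
#   torch.special.chebyshev_polynomial_t(x, torch.tensor(3.))  # tensor overload
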
- func: special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_chebyshev_polynomial_u.out
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_u
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_u
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_chebyshev_polynomial_u_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_u_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_u_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_chebyshev_polynomial_v.out
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_v
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_v
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_chebyshev_polynomial_v_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_v_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_v_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_chebyshev_polynomial_w.out
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_w
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_w
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_chebyshev_polynomial_w_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_w_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_chebyshev_polynomial_w_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_hermite_polynomial_h.out
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_hermite_polynomial_h
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_hermite_polynomial_h
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_hermite_polynomial_h_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_hermite_polynomial_h_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_hermite_polynomial_h_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_hermite_polynomial_he.out
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_hermite_polynomial_he
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_hermite_polynomial_he
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_hermite_polynomial_he_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_hermite_polynomial_he_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_hermite_polynomial_he_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_laguerre_polynomial_l.out
  variants: function
  tags: pointwise

- func: special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_laguerre_polynomial_l
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_laguerre_polynomial_l
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_laguerre_polynomial_l_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_laguerre_polynomial_l_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_laguerre_polynomial_l_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_legendre_polynomial_p.out
  variants: function
  tags: pointwise

- func: special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_legendre_polynomial_p
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_legendre_polynomial_p
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_legendre_polynomial_p_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_legendre_polynomial_p_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_legendre_polynomial_p_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_modified_bessel_i0(Tensor self) -> Tensor
  python_module: special
  structured_delegate: special_modified_bessel_i0.out
  variants: function
  tags: pointwise

- func: special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_modified_bessel_i0_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_modified_bessel_i1(Tensor self) -> Tensor
  python_module: special
  structured_delegate: special_modified_bessel_i1.out
  variants: function
  tags: pointwise

- func: special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_modified_bessel_i1_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_modified_bessel_k0(Tensor self) -> Tensor
  python_module: special
  structured_delegate: special_modified_bessel_k0.out
  variants: function
  tags: pointwise

- func: special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_modified_bessel_k0_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_modified_bessel_k1(Tensor self) -> Tensor
  python_module: special
  structured_delegate: special_modified_bessel_k1.out
  variants: function
  tags: pointwise

- func: special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_modified_bessel_k1_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_scaled_modified_bessel_k0(Tensor x) -> Tensor
  python_module: special
  structured_delegate: special_scaled_modified_bessel_k0.out
  variants: function
  tags: pointwise

- func: special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_scaled_modified_bessel_k0_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_scaled_modified_bessel_k1(Tensor x) -> Tensor
  python_module: special
  structured_delegate: special_scaled_modified_bessel_k1.out
  variants: function
  tags: pointwise

- func: special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_scaled_modified_bessel_k1_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_shifted_chebyshev_polynomial_t.out
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_shifted_chebyshev_polynomial_t_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_shifted_chebyshev_polynomial_u.out
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_shifted_chebyshev_polynomial_u_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_shifted_chebyshev_polynomial_v.out
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_shifted_chebyshev_polynomial_v_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor
  device_check: NoCheck
  python_module: special
  structured_delegate: special_shifted_chebyshev_polynomial_w.out
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CPU, CUDA: special_shifted_chebyshev_polynomial_w_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w_out
  device_check: NoCheck
  python_module: special
  variants: function
  tags: pointwise

- func: special_spherical_bessel_j0(Tensor x) -> Tensor
  python_module: special
  structured_delegate: special_spherical_bessel_j0.out
  variants: function
  tags: pointwise

- func: special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA: special_spherical_bessel_j0_out
  python_module: special
  structured_inherits: TensorIteratorBase
  structured: True
  variants: function
  tags: pointwise

# Aux function used in the test TestPythonDispatch.test_kwarg_only_and_positional_default
# within test/test_python_dispatch.py
- func: _foobar(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True) -> Tensor
  dispatch:
    CPU: foobar
  autogen: _foobar.out

- func: _fused_adam_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
  variants: function
  dispatch:
    CPU: _fused_adam_kernel_cpu_
    CUDA: _fused_adam_kernel_cuda_
  autogen: _fused_adam, _fused_adam.out

- func: _fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now),
  # but still skip the device check as the Tensor LR can be on CPU
  device_check: NoCheck
  variants: function
  dispatch:
    CPU: _fused_adam_kernel_cpu_
    CUDA: _fused_adam_kernel_cuda_
  autogen: _fused_adam.tensor_lr, _fused_adam.tensor_lr_out

- func: _fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
  variants: function
  dispatch:
    CPU: _fused_adamw_kernel_cpu_
    CUDA: _fused_adamw_kernel_cuda_
  autogen: _fused_adamw, _fused_adamw.out

- func: _fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now),
  # but still skip the device check as the Tensor LR can be on CPU
  device_check: NoCheck
  variants: function
  dispatch:
    CPU: _fused_adamw_kernel_cpu_
    CUDA: _fused_adamw_kernel_cuda_
  autogen: _fused_adamw.tensor_lr, _fused_adamw.tensor_lr_out

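# These fused kernels are what torch.optim.Adam / torch.optim.AdamW dispatch to
# when constructed with fused=True (a sketch; the optimizers fall back to
# foreach/for-loop implementations where fused is unsupported). The tensor_lr
# overloads let the learning rate be passed as a 0-dim tensor, possibly on CPU,
# without a device check, per the comments above. Usage sketch:
#
#   model = torch.nn.Linear(4, 4).cuda()
#   opt = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)
#   loss = model(torch.randn(2, 4, device="cuda")).sum()
#   loss.backward()
#   opt.step()   # calls _fused_adam_ on the CUDA parameters
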
- func: _fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
  variants: function
  dispatch:
    CPU: _fused_sgd_kernel_cpu_
    CUDA: _fused_sgd_kernel_cuda_
  autogen: _fused_sgd, _fused_sgd.out

- func: _fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now),
  # but still skip the device check as the Tensor LR can be on CPU
  device_check: NoCheck
  variants: function
  dispatch:
    CPU: _fused_sgd_kernel_cpu_
    CUDA: _fused_sgd_kernel_cuda_
  autogen: _fused_sgd.tensor_lr, _fused_sgd.tensor_lr_out

- func: _fused_adagrad_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
  variants: function
  dispatch:
    CPU: _fused_adagrad_kernel_cpu_
  autogen: _fused_adagrad, _fused_adagrad.out

# This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
  variants: function