"""
Python implementation of ``__torch_function__``

While most of the torch API and handling for ``__torch_function__`` happens
at the C++ level, some of the torch API is written in Python so we need
python-level handling for ``__torch_function__`` overrides as well. The main
developer-facing functions in this file are handle_torch_function and
has_torch_function. See torch/functional.py and test/test_overrides.py
for usage examples.

Note
----
heavily inspired by NumPy's ``__array_function__`` (see:
https://github.com/pytorch/pytorch/issues/24015 and
https://www.numpy.org/neps/nep-0018-array-function-protocol.html
)

If changing this file in a way that can affect ``__torch_function__`` overhead,
please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
instructions in the ``README.md`` in that directory.
"""

import __future__  # noqa: F404

import collections
import contextlib
import functools
import types
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type

import torch
from torch._C import (
    _add_docstr,
    _get_function_stack_at,
    _has_torch_function,
    _has_torch_function_unary,
    _has_torch_function_variadic,
    _is_torch_function_mode_enabled,
    _len_torch_function_stack,
    _pop_torch_function_stack,
    _push_on_torch_function_stack,
)


__all__ = [
    "get_ignored_functions",
    "get_overridable_functions",
    "get_testing_overrides",
    "handle_torch_function",
    "has_torch_function",
    "resolve_name",
    "is_tensor_like",
    "is_tensor_method_or_property",
    "wrap_torch_function",
    "enable_reentrant_dispatch",
]


def _disable_user_warnings(
    func: Callable,
    regex: str = ".*is deprecated, please use.*",
    module: str = "torch",
) -> Callable:
    """
    Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
    given ``regex`` pattern.

    The filter is installed inside a ``warnings.catch_warnings()`` context, so it
    is only active for the duration of each call to the wrapped function and the
    previous warning filters are restored afterwards.

    Arguments
    ---------
    func : function
        Function to disable the warnings for.
    regex : str
        A regex pattern compilable by ``re.compile``. It is passed as the ``message``
        argument of ``warnings.filterwarnings``, so it must match the *start* of the
        ``UserWarning`` message (case-insensitively).
    module : str
        A regex that must match the *start* of the fully-qualified name of the
        module issuing the warning; used to restrict the filtering to that module.

    Returns
    -------
    function
        The wrapped function.

    Note
    ----
    ``warnings.catch_warnings`` mutates process-global warning state and is not
    thread-safe.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", category=UserWarning, message=regex, module=module
            )
            return func(*args, **kwargs)

    return wrapper


@functools.lru_cache(None)
@_disable_user_warnings
def get_ignored_functions() -> Set[Callable]:
    """
    Return public functions that cannot be overridden by ``__torch_function__``.

    The result is computed once and cached (``functools.lru_cache``); deprecation
    ``UserWarning``s raised while touching deprecated attributes are suppressed.

    Returns
    -------
    Set[Callable]
        A set of functions that are publicly available in the torch API but cannot
        be overridden with ``__torch_function__``. Mostly this is because none of the
        arguments of these functions are tensors or tensor-likes.

    Examples
    --------
    >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
    True
    >>> torch.add in torch.overrides.get_ignored_functions()
    False
    """
    Tensor = torch.Tensor
    return {
        torch.typename,
        torch.is_tensor,
        torch.is_storage,
        torch.set_default_tensor_type,
        torch.set_default_device,
        torch.get_default_device,
        torch.set_rng_state,
        torch.get_rng_state,
        torch.manual_seed,
        torch.initial_seed,
        torch.seed,
        torch.save,
        torch.load,
        torch.set_printoptions,
        torch.fork,
        torch.get_default_dtype,
        torch.get_num_interop_threads,
        torch.get_num_threads,
        torch.init_num_threads,
        torch.import_ir_module,
        torch.import_ir_module_from_buffer,
        torch.is_anomaly_enabled,
        torch.is_anomaly_check_nan_enabled,
        torch.is_grad_enabled,
        torch.merge_type_from_type_comment,
        torch.parse_ir,
        torch.parse_schema,
        torch.parse_type_comment,
        torch.set_anomaly_enabled,
        torch.set_flush_denormal,
        torch.set_num_interop_threads,
        torch.set_num_threads,
        torch.wait,
        torch.as_tensor,
        torch.from_numpy,
        torch.get_device,
        torch.tensor,
        torch.default_generator,
        torch.has_cuda,
        torch.has_cudnn,
        torch.has_lapack,
        torch.device,
        torch.dtype,
        torch.finfo,
        torch.has_mkl,
        torch.has_mps,
        torch.has_mkldnn,
        torch.has_openmp,
        torch.iinfo,
        torch.memory_format,
        torch.qscheme,
        torch.set_grad_enabled,
        torch.no_grad,
        torch.enable_grad,
        torch.inference_mode,
        torch.is_inference_mode_enabled,
        torch.layout,
        torch.align_tensors,
        torch.arange,
        torch.as_strided,
        torch.bartlett_window,
        torch.blackman_window,
        torch.broadcast_shapes,
        torch.can_cast,
        torch.compile,
        torch.cudnn_affine_grid_generator,
        torch.cudnn_batch_norm,
        torch.cudnn_convolution,
        torch.cudnn_convolution_transpose,
        torch.cudnn_convolution_relu,
        torch.cudnn_convolution_add_relu,
        torch.cudnn_grid_sampler,
        torch.cudnn_is_acceptable,
        torch.empty,
        torch.empty_permuted,
        torch.empty_strided,
        torch.empty_quantized,
        torch.export.export,
        torch.export.load,
        torch.export.register_dataclass,
        torch.export.save,
        torch.eye,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        torch.from_file,
        torch.full,
        torch.fill,
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        torch.linspace,
        torch.logspace,
        torch.mkldnn_adaptive_avg_pool2d,
        torch.mkldnn_convolution,
        torch.mkldnn_max_pool2d,
        torch.mkldnn_max_pool3d,
        torch.mkldnn_linear_backward_weights,
        torch.mkldnn_rnn_layer,
        torch.normal,
        torch.ones,
        torch.promote_types,
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        torch.range,
        torch.result_type,
        torch.scalar_tensor,
        torch.sparse_coo_tensor,
        torch.sparse_compressed_tensor,
        torch.sparse_csr_tensor,
        torch.sparse_csc_tensor,
        torch.sparse_bsr_tensor,
        torch.sparse_bsc_tensor,
        torch.sym_constrain_range,
        torch.sym_constrain_range_for_size,
        torch.tril_indices,
        torch.triu_indices,
        torch.vander,
        torch.zeros,
        torch._jit_internal.boolean_dispatch,
        torch.nn.functional.assert_int_or_pair,
        torch.nn.functional.upsample,
        torch.nn.functional.upsample_bilinear,
        torch.nn.functional.upsample_nearest,
        torch.nn.functional.has_torch_function,
        torch.nn.functional.has_torch_function_unary,
        torch.nn.functional.has_torch_function_variadic,
        torch.nn.functional.handle_torch_function,
        torch.nn.functional.sigmoid,
        torch.nn.functional.hardsigmoid,
        torch.nn.functional.tanh,
        torch.nn.functional._canonical_mask,
        torch.nn.functional._none_or_dtype,
        # Doesn't actually take or return tensor arguments
        torch.nn.init.calculate_gain,
        # These are deprecated; don't test them
        torch.nn.init.uniform,
        torch.nn.init.normal,
        torch.nn.init.constant,
        torch.nn.init.eye,
        torch.nn.init.dirac,
        torch.nn.init.xavier_uniform,
        torch.nn.init.xavier_normal,
        torch.nn.init.kaiming_uniform,
        torch.nn.init.kaiming_normal,
        torch.nn.init.orthogonal,
        torch.nn.init.sparse,
        torch.nested.to_padded_tensor,
        has_torch_function,
        handle_torch_function,
        torch.set_autocast_enabled,
        torch.is_autocast_enabled,
        torch.set_autocast_dtype,
        torch.get_autocast_dtype,
        torch.clear_autocast_cache,
        torch.set_autocast_cpu_enabled,
        torch.is_autocast_cpu_enabled,
        torch.set_autocast_xla_enabled,
        torch.is_autocast_xla_enabled,
        torch.set_autocast_ipu_enabled,
        torch.is_autocast_ipu_enabled,
        torch.set_autocast_cpu_dtype,
        torch.get_autocast_cpu_dtype,
        torch.set_autocast_ipu_dtype,
        torch.get_autocast_ipu_dtype,
        torch.get_autocast_gpu_dtype,
        torch.set_autocast_gpu_dtype,
        torch.get_autocast_xla_dtype,
        torch.set_autocast_xla_dtype,
        torch.autocast_increment_nesting,
        torch.autocast_decrement_nesting,
        torch.is_autocast_cache_enabled,
        torch.set_autocast_cache_enabled,
        torch.nn.functional.hardswish,
        torch.is_vulkan_available,
        torch.are_deterministic_algorithms_enabled,
        torch.use_deterministic_algorithms,
        torch.is_deterministic_algorithms_warn_only_enabled,
        torch.set_deterministic_debug_mode,
        torch.get_device_module,
        torch.get_deterministic_debug_mode,
        torch.set_float32_matmul_precision,
        torch.get_float32_matmul_precision,
        torch.unify_type_list,
        torch.is_warn_always_enabled,
        torch.set_warn_always,
        torch.vitals_enabled,
        torch.set_vital,
        torch.read_vitals,
        torch.vmap,
        torch.cond,
        torch.frombuffer,
        torch.asarray,
        torch._functional_sym_constrain_range,
        torch._make_dep_token,
        Tensor.__delitem__,
        Tensor.__dir__,
        Tensor.__getattribute__,
        Tensor.__init__,
        Tensor.__iter__,
        Tensor.__init_subclass__,
        Tensor.__delattr__,
        Tensor.__setattr__,
        Tensor.__torch_function__,
        Tensor.__torch_dispatch__,
        Tensor.__new__,
        Tensor.__class__,
        Tensor.__subclasshook__,
        Tensor.__hash__,
        Tensor.as_subclass,
        Tensor.eig,
        Tensor.lstsq,
        Tensor.reinforce,
        Tensor.new,
        Tensor.new_tensor,
        Tensor.new_empty,
        Tensor.new_empty_strided,
        Tensor.new_zeros,
        Tensor.new_ones,
        Tensor.new_full,
        Tensor._make_subclass,
        Tensor.solve,
        Tensor.symeig,
        Tensor.stride,
        Tensor.unflatten,
        Tensor.to_sparse_coo,
        Tensor.to_sparse_csr,
        Tensor.to_sparse_csc,
        Tensor.to_sparse_bsr,
        Tensor.to_sparse_bsc,
        Tensor._to_sparse,
        Tensor._to_sparse_csr,
        Tensor._to_sparse_csc,
        Tensor._to_sparse_bsr,
        Tensor._to_sparse_bsc,
        Tensor._typed_storage,
        Tensor._reduce_ex_internal,
        Tensor._fix_weakref,
        Tensor._view_func,
        Tensor._view_func_unsafe,
        Tensor._rev_view_func_unsafe,
        Tensor._make_wrapper_subclass,
        Tensor._python_dispatch.__get__,
        Tensor._has_symbolic_sizes_strides.__get__,
        Tensor._conj,
        Tensor._conj_physical,
        Tensor._lazy_clone,
        Tensor._neg_view,
        Tensor._is_zerotensor,
        Tensor._is_all_true,
        Tensor._is_any_true,
        Tensor._addmm_activation,
        Tensor.to_padded_tensor,
        Tensor._use_count,
    }


@functools.lru_cache(None)
def get_default_nowrap_functions() -> Set[Callable]:
    """
    Return public functions that do not wrap in a subclass when invoked by
    the default ``Tensor.__torch_function__`` that preserves subclasses. Typically,
    these functions represent field accesses (i.e., retrieving a Tensor that
    is stored somewhere on the Tensor) as opposed to computation. Users of
    these functions expect object identity to be preserved over multiple accesses
    (e.g., ``a.grad is a.grad``) which cannot be upheld if we're wrapping on
    the fly every time (furthermore, the tensor stored here might already be
    the subclass, in which case wrapping really ought not to happen).

    Not ALL property accessors have this property; for example ``Tensor.T`` actually
    just creates a new transposed tensor on the fly, and so we SHOULD interpose on
    these calls (you need to check the implementation of the function to see if
    this is the case or not). Additionally, if a property accessor doesn't return a Tensor,
    it doesn't have to be on this list (though it is harmless if it is).

    Returns
    -------
    Set[Callable]
        The property getters (``__get__``) of ``Tensor._base``, ``Tensor.grad``
        and ``Tensor._grad``. The set is built once and cached.
    """
    Tensor = torch.Tensor
    return {
        Tensor._base.__get__,
        Tensor.grad.__get__,
        Tensor._grad.__get__,
    }


@functools.lru_cache(None)
@_disable_user_warnings
def get_testing_overrides() -> Dict[Callable, Callable]:
    """Return a dict containing dummy overrides for all overridable functions

    Returns
    -------
    Dict[Callable, Callable]
        A dictionary that maps overridable functions in the PyTorch API to
        lambda functions that have the same signature as the real function
        and unconditionally return -1. These lambda functions are useful
        for testing API coverage for a type that defines ``__torch_function__``.

    Examples
    --------
    >>> import inspect
    >>> my_add = torch.overrides.get_testing_overrides()[torch.add]
    >>> inspect.signature(my_add)
    <Signature (input, other, out=None)>
    """
    # Every function in the PyTorchAPI that can be overriden needs an entry
    # in this dict.
    #
    # Optimally we would use inspect to get the function signature and define
    # the lambda function procedurally but that is blocked by generating
    # function signatures for native kernels that can be consumed by inspect.
    # See Issue #28233.
429 Tensor = torch.Tensor 430 ret: Dict[Callable, Callable] = { 431 torch.abs: lambda input, out=None: -1, 432 torch.absolute: lambda input, out=None: -1, 433 torch.adaptive_avg_pool1d: lambda input, output_size: -1, 434 torch.adaptive_max_pool1d: lambda inputs, output_size: -1, 435 torch.acos: lambda input, out=None: -1, 436 torch.adjoint: lambda input: -1, 437 torch.arccos: lambda input, out=None: -1, 438 torch.acosh: lambda input, out=None: -1, 439 torch.arccosh: lambda input, out=None: -1, 440 torch.add: lambda input, other, out=None: -1, 441 torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1, 442 torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1, 443 torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1, 444 torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1, 445 torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1, 446 torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1, 447 torch.affine_grid_generator: lambda theta, size, align_corners: -1, 448 torch.all: lambda input, dim=None: -1, 449 torch.allclose: lambda input, other, trol=1e-05, atol=1e-08, equal_nan=False: -1, 450 torch.alpha_dropout: lambda input, p, train, inplace=False: -1, 451 torch.amax: lambda input, dim=None: -1, 452 torch.amin: lambda input, dim=None: -1, 453 torch.aminmax: lambda input, dim=None, keepdim=False, out=None: -1, 454 torch.angle: lambda input, out=None: -1, 455 torch.any: lambda input, dim=None, keepdim=False, out=None: -1, 456 torch.argmax: lambda input: -1, 457 torch.argmin: lambda input: -1, 458 torch.argsort: lambda input, dim=None: -1, 459 torch.asin: lambda input, out=None: -1, 460 torch._assert_async: lambda input, msg: -1, 461 torch.arcsin: lambda input, out=None: -1, 462 torch.asinh: lambda input, out=None: -1, 463 torch.arcsinh: lambda input, out=None: -1, 464 torch.atan: lambda input, out=None: -1, 465 torch.arctan: lambda input, out=None: -1, 466 torch.atan2: lambda 
input, other, out=None: -1, 467 torch.arctan2: lambda input, other, out=None: -1, 468 torch.atanh: lambda input, out=None: -1, 469 torch.arctanh: lambda input, out=None: -1, 470 torch.atleast_1d: lambda *tensors: -1, 471 torch.atleast_2d: lambda *tensors: -1, 472 torch.atleast_3d: lambda *tensors: -1, 473 torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1, 474 torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1, 475 torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1, 476 torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor: -1, 477 torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1, 478 torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1, 479 torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1, 480 torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1, 481 torch.batch_norm_stats: lambda input, eps: -1, 482 torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1, 483 torch.bernoulli: lambda input, generator=None, out=None: -1, 484 torch.bilinear: lambda input1, input2, weight, bias: -1, 485 torch.binary_cross_entropy_with_logits: ( 486 lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1 487 ), 488 torch.bincount: lambda input, weights=None, minlength=0: -1, 489 torch.binomial: lambda count, prob, generator=None: -1, 490 torch.bitwise_and: lambda input, other, out=None: -1, 491 torch.bitwise_not: lambda input, out=None: -1, 492 torch.bitwise_or: lambda input, other, out=None: -1, 493 torch.bitwise_xor: lambda input, other, out=None: -1, 494 torch.bitwise_left_shift: lambda 
input, other, out=None: -1, 495 torch.bitwise_right_shift: lambda input, other, out=None: -1, 496 torch.block_diag: lambda *tensors: -1, 497 torch.bmm: lambda input, mat2, out=None: -1, 498 torch.broadcast_tensors: lambda *tensors: -1, 499 torch.broadcast_to: lambda self, size: -1, 500 torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1, 501 torch.cartesian_prod: lambda *tensors: -1, 502 torch.cat: lambda tensors, dim=0, out=None: -1, 503 torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat 504 torch.concatenate: lambda tensors, dim=0, out=None: -1, # alias for torch.concatenate 505 torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1, 506 torch.ceil: lambda input, out=None: -1, 507 torch.celu: lambda input, alpha=1.0, inplace=False: -1, 508 torch.chain_matmul: lambda *matrices, out=None: -1, 509 torch.channel_shuffle: lambda input, groups: -1, 510 torch.cholesky: lambda input, upper=False, out=None: -1, 511 torch.linalg.cholesky: lambda input, out=None: -1, 512 torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1, 513 torch.cholesky_inverse: lambda input, upper=False, out=None: -1, 514 torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1, 515 torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1, 516 torch.chunk: lambda input, chunks, dim=0: -1, 517 torch.clamp: lambda input, min=None, max=None, out=None: -1, 518 torch.clip: lambda input, min=None, max=None, out=None: -1, 519 torch.clamp_min: lambda input, min, out=None: -1, 520 torch.clamp_max: lambda input, max, out=None: -1, 521 torch.column_stack: lambda tensors, out=None: -1, 522 torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1, 523 torch.clone: lambda input: -1, 524 torch.combinations: lambda input, r=2, with_replacement=False: -1, 525 torch.complex: lambda real, imag: -1, 526 torch.copysign: lambda input, other, out=None: -1, 527 
torch.polar: lambda abs, ang: -1, 528 torch.linalg.cond: lambda input, ord=None: -1, 529 torch.conj: lambda input, out=None: -1, 530 torch.conj_physical: lambda input, out=None: -1, 531 torch.resolve_conj: lambda input, out=None: -1, 532 torch.resolve_neg: lambda input, out=None: -1, 533 torch.constant_pad_nd: lambda input, pad, value=0: -1, 534 torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, 535 torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, 536 torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1, 537 torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1, 538 torch.conv_tbc: lambda input, weight, bias, pad=0: -1, 539 torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1, 540 torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1, 541 torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1, 542 torch.corrcoef: lambda input: -1, 543 torch.cos: lambda input, out=None: -1, 544 torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1, 545 torch.cosh: lambda input, out=None: -1, 546 torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1, 547 torch.count_nonzero: lambda input: -1, 548 torch.cross: lambda input, other, dim=None, out=None: -1, 549 torch.linalg.cross: lambda input, other, dim=-1, out=None: -1, 550 torch.ctc_loss: ( 551 lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1 552 ), 553 torch.cummax: lambda input, dim, out=None: -1, 554 torch.cummin: lambda input, dim, out=None: -1, 555 torch.cumprod: lambda input, dim, out=None, dtype=None: -1, 556 
torch.cumsum: lambda input, dim, out=None, dtype=None: -1, 557 torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1, 558 torch.logcumsumexp: lambda input, dim, out=None: -1, 559 torch.deg2rad: lambda input, out=None: -1, 560 torch.dequantize: lambda input: -1, 561 torch.det: lambda input: -1, 562 torch.linalg.det: lambda input: -1, # alias for torch.det # type: ignore[attr-defined] 563 torch.detach: lambda input: -1, 564 torch.diag: lambda input, diagonal=0, out=None: -1, 565 torch.diag_embed: lambda input, diagonal=0, out=None: -1, 566 torch.diagflat: lambda input, offset=0: -1, 567 torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1, 568 torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1, 569 torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1, 570 torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1, 571 torch.as_strided_scatter: lambda self, src, size, stride, storage_offset=None: -1, 572 torch.digamma: lambda input, out=None: -1, 573 torch.dist: lambda input, other, p=2: -1, 574 torch.div: lambda input, other, rounding_mode=None, out=None: -1, 575 torch.divide: lambda input, other, rounding_mode=None, out=None: -1, 576 torch.dot: lambda input, other, out=None: -1, 577 torch.dropout: lambda input, p, train, inplace=False: -1, 578 torch.dsmm: lambda input, mat2: -1, 579 torch.hsmm: lambda mat1, mat2: -1, 580 torch.dsplit: lambda input, indices_or_sections: -1, 581 torch.dstack: lambda tensors, out=None: -1, 582 torch.linalg.eig: lambda input, out=None: -1, 583 torch.linalg.eigvals: lambda input, out=None: -1, 584 torch.linalg.eigh: lambda input, UPLO="L", out=None: -1, 585 torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1, 586 torch.einsum: lambda equation, *operands: -1, 587 torch.embedding: ( 588 lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950 589 ), 590 torch.embedding_bag: ( 591 lambda input, 
weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1 # noqa: B950 592 ), 593 torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, 594 torch.eq: lambda input, other, out=None: -1, 595 torch.equal: lambda input, other: -1, 596 torch.erf: lambda input, out=None: -1, 597 torch.erfc: lambda input, out=None: -1, 598 torch.erfinv: lambda input, out=None: -1, 599 torch.exp: lambda input, out=None: -1, 600 torch.exp2: lambda input, out=None: -1, 601 torch.expm1: lambda input, out=None: -1, 602 torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1, 603 torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1, 604 torch.fused_moving_avg_obs_fake_quant: ( 605 lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1 # noqa: B950 606 ), 607 torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1, 608 torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1, 609 torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, # noqa: B950 610 torch.fbgemm_linear_int8_weight_fp32_activation: ( 611 lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1 612 ), 613 torch.fbgemm_linear_quantize_weight: lambda input: -1, 614 torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1, 615 torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1, 616 torch.feature_alpha_dropout: lambda input, p, train: -1, 617 torch.feature_dropout: lambda input, p, train: -1, 618 torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1, 619 torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1, 620 torch.fft.irfft: lambda input, 
n=None, dim=-1, norm=None: -1, 621 torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1, 622 torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1, 623 torch.fft.hfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, 624 torch.fft.ihfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, 625 torch.fft.hfftn: lambda input, s=None, dim=-1, norm=None: -1, 626 torch.fft.ihfftn: lambda input, s=None, dim=-1, norm=None: -1, 627 torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1, 628 torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1, 629 torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1, 630 torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1, 631 torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, 632 torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, 633 torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, 634 torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1, 635 torch.fft.fftshift: lambda input, dim=None: -1, 636 torch.fft.ifftshift: lambda input, dim=None: -1, 637 torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1, 638 torch.fix: lambda input, out=None: -1, 639 torch.flatten: lambda input, start_dim=0, end_dim=-1: -1, 640 torch.flip: lambda input, dims: -1, 641 torch.fliplr: lambda input: -1, 642 torch.flipud: lambda input: -1, 643 torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1, 644 torch.floor: lambda input, out=None: -1, 645 torch.floor_divide: lambda input, other: -1, 646 torch.float_power: lambda input, exponent, out=None: -1, 647 torch.fmod: lambda input, other, out=None: -1, 648 torch.frac: lambda input, out=None: -1, 649 torch.frexp: lambda input, out=None: -1, 650 torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, # noqa: B950 651 torch._functional_assert_async: lambda input, msg, dep_token: -1, 652 torch.lu_unpack: lambda 
LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1, 653 torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1, 654 torch.gcd: lambda input, other, out=None: -1, 655 torch.ge: lambda input, other, out=None: -1, 656 torch.greater_equal: lambda input, other, out=None: -1, 657 torch.geqrf: lambda input, out=None: -1, 658 torch.i0: lambda input, out=None: -1, 659 torch.inner: lambda input, other, out=None: -1, 660 torch.outer: lambda input, vec2, out=None: -1, 661 torch.ger: lambda input, vec2, out=None: -1, # alias for torch.outer 662 torch.gradient: lambda input, spacing=None, dim=None, edge_order=1: -1, 663 torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1, 664 torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1, 665 torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1, 666 torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1, 667 torch.gru: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, 668 torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, 669 torch.gt: lambda input, other, out=None: -1, 670 torch.greater: lambda input, other, out=None: -1, 671 torch.hardshrink: lambda input, lambd=0.5: -1, 672 torch.heaviside: lambda input, values, out=None: -1, 673 torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950 674 torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1, 675 torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1, 676 torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1, 677 torch.linalg.householder_product: lambda input, tau: -1, 678 torch.hspmm: lambda mat1, mat2, out=None: -1, 679 torch.hsplit: lambda input, 
indices_or_sections: -1, 680 torch.hstack: lambda tensors, out=None: -1, 681 torch.hypot: lambda input, other, out=None: -1, 682 torch.igamma: lambda input, other, out=None: -1, 683 torch.igammac: lambda input, other, out=None: -1, 684 torch.imag: lambda input, out=None: -1, 685 torch.index_add: lambda input, dim, index, source: -1, 686 torch.index_copy: lambda input, dim, index, source: -1, 687 torch.index_put: lambda input, indices, values, accumulate=False: -1, 688 torch.index_select: lambda input, dim, index, out=None: -1, 689 torch.index_fill: lambda input, dim, index, value: -1, 690 torch.index_reduce: lambda input, dim, index, source, reduce, include_input=True: -1, 691 torch.isfinite: lambda tensor: -1, 692 torch.isin: lambda e, te, assume_unique=False, invert=False: -1, 693 torch.isinf: lambda tensor: -1, 694 torch.isreal: lambda tensor: -1, 695 torch.isposinf: lambda input, out=None: -1, 696 torch.isneginf: lambda input, out=None: -1, 697 torch.instance_norm: ( 698 lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1 699 ), 700 torch.int_repr: lambda input: -1, 701 torch.inverse: lambda input, out=None: -1, 702 torch.linalg.inv: lambda input, out=None: -1, 703 torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1, 704 torch.is_complex: lambda input: -1, 705 torch.is_conj: lambda input: -1, 706 torch.is_neg: lambda input: -1, 707 torch.is_distributed: lambda input: -1, 708 torch.is_inference: lambda input: -1, 709 torch.is_floating_point: lambda input: -1, 710 torch.is_nonzero: lambda input: -1, 711 torch.is_same_size: lambda input, other: -1, 712 torch.is_signed: lambda input: -1, 713 torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1, 714 torch.isnan: lambda input: -1, 715 torch.istft: ( 716 lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1 # noqa: B950 717 ), 
718 torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, 719 torch.kron: lambda input, other: -1, 720 torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1, 721 torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1, 722 torch.linalg.ldl_factor: lambda input, hermitian=False, out=None: -1, 723 torch.linalg.ldl_solve: lambda LD, pivots, B, hermitian=False, out=None: -1, 724 torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1, 725 torch.lcm: lambda input, other, out=None: -1, 726 torch.ldexp: lambda input, other, out=None: -1, 727 torch.le: lambda input, other, out=None: -1, 728 torch.less_equal: lambda input, other, out=None: -1, 729 torch.lerp: lambda input, end, weight, out=None: -1, 730 torch.lgamma: lambda input, out=None: -1, 731 torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1, # noqa: B950 732 torch.log: lambda input, out=None: -1, 733 torch.log_softmax: lambda input, dim, dtype=None: -1, 734 torch.log10: lambda input, out=None: -1, 735 torch.log1p: lambda input, out=None: -1, 736 torch.log2: lambda input, out=None: -1, 737 torch.logaddexp: lambda input, other, out=None: -1, 738 torch.logaddexp2: lambda input, other, out=None: -1, 739 torch.logdet: lambda input: -1, 740 torch.xlogy: lambda x, y, out=None: -1, 741 torch.logical_and: lambda input, other, out=None: -1, 742 torch.logical_not: lambda input, out=None: -1, 743 torch.logical_or: lambda input, other, out=None: -1, 744 torch.logical_xor: lambda input, other, out=None: -1, 745 torch.logit: lambda input, eps=None: -1, 746 torch.logsumexp: lambda input, names, keepdim=False, out=None: -1, 747 torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1, 
748 torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, 749 torch.lt: lambda input, other, out=None: -1, 750 torch.less: lambda input, other, out=None: -1, 751 torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1, 752 torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1, 753 torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1, # type: ignore[attr-defined] # noqa: B950 754 torch.masked_fill: lambda input, mask, value: -1, 755 torch.masked_scatter: lambda input, mask, source: -1, 756 torch.masked_select: lambda input, mask, out=None: -1, 757 torch.matmul: lambda input, other, out=None: -1, 758 torch.linalg.lu: lambda input, pivot=True, out=None: -1, 759 torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1, 760 torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1, 761 torch.linalg.lu_solve: lambda LU, pivots, B, left=True, adjoint=False, out=None: -1, 762 torch.linalg.matmul: lambda input, other, out=None: -1, # alias for torch.matmul 763 torch.matrix_power: lambda input, n: -1, 764 torch.linalg.matrix_power: lambda input, n, out=None: -1, 765 torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1, 766 torch.linalg.multi_dot: lambda tensors, out=None: -1, 767 torch.matrix_exp: lambda input: -1, 768 torch.linalg.matrix_exp: lambda input: -1, 769 torch.max: lambda input, out=None: -1, 770 torch.maximum: lambda input, other, out=None: -1, 771 torch.fmax: lambda input, other, out=None: -1, 772 torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, 773 torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, 774 torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1, 775 torch.max_pool1d_with_indices: ( 776 lambda input, kernel_size, stride=None, padding=0, dilation=1, 
return_indices=False, ceil_mode=False: -1 777 ), 778 torch.mean: lambda input, dim=None: -1, 779 torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1, 780 torch.median: lambda input, dim=None: -1, 781 torch.nanmedian: lambda input, dim=None: -1, 782 torch.meshgrid: lambda *tensors, **kwargs: -1, 783 torch.min: lambda input, out=None: -1, 784 torch.minimum: lambda input, other, out=None: -1, 785 torch.fmin: lambda input, other, out=None: -1, 786 torch.miopen_batch_norm: ( 787 lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1 788 ), 789 torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, # noqa: B950 790 torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1, 791 torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1, 792 torch.miopen_convolution_transpose: ( 793 lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1 794 ), 795 torch.miopen_depthwise_convolution: ( 796 lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1 797 ), 798 torch.miopen_rnn: ( 799 lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1 # noqa: B950 800 ), 801 torch.mm: lambda input, mat2, out=None: -1, 802 torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1, 803 torch.movedim: lambda input, source, destination: -1, 804 torch.moveaxis: lambda input, source, destination: -1, 805 torch.msort: lambda input, descending=False, out=None: -1, 806 torch.mul: lambda input, other, out=None: -1, 807 torch.multiply: lambda input, other, out=None: -1, 808 torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1, 809 torch.mv: lambda input, vec, 
out=None: -1, 810 torch.mvlgamma: lambda input, p: -1, 811 torch.narrow: lambda input, dim, start, length: -1, 812 torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1, 813 torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1, 814 torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1, 815 torch.native_dropout: lambda input, p, train: -1, 816 torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, 817 torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1, 818 torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1, 819 torch.native_channel_shuffle: lambda input, groups: -1, 820 torch.ne: lambda input, other, out=None: -1, 821 torch.not_equal: lambda input, other, out=None: -1, 822 torch.neg: lambda input, out=None: -1, 823 torch.negative: lambda input, out=None: -1, 824 torch.nextafter: lambda input, other, out=None: -1, 825 torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1, 826 torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1, 827 torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1, 828 torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1, 829 torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1, 830 torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1, 831 torch.nn.functional.adaptive_max_pool3d: lambda input, output_size, return_indices=False: -1, 832 torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1, 833 torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1, 834 torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1, 835 
torch.nn.functional.avg_pool2d: ( 836 lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950 837 ), 838 torch.nn.functional.avg_pool3d: ( 839 lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950 840 ), 841 torch.nn.functional.batch_norm: ( 842 lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1 843 ), 844 torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1, 845 torch.nn.functional.binary_cross_entropy: ( 846 lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1 847 ), 848 torch.nn.functional.binary_cross_entropy_with_logits: ( 849 lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1 850 ), 851 torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1, 852 torch.nn.functional.cosine_embedding_loss: ( 853 lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1 854 ), 855 torch.nn.functional.cross_entropy: ( 856 lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1 # noqa: B950 857 ), 858 torch.nn.functional.ctc_loss: ( 859 lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1 860 ), 861 torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1, 862 torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1, 863 torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1, 864 torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1, 865 torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1, 866 torch.nn.functional.embedding: ( 867 lambda input, weight, 
padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950 868 ), 869 torch.nn.functional.embedding_bag: ( 870 lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1 # noqa: B950 871 ), 872 torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1, 873 torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1, 874 torch.nn.functional.fractional_max_pool2d: ( 875 lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 876 ), 877 torch.nn.functional.fractional_max_pool2d_with_indices: ( 878 lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 879 ), 880 torch.nn.functional.fractional_max_pool3d: ( 881 lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 882 ), 883 torch.nn.functional.fractional_max_pool3d_with_indices: ( 884 lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 885 ), 886 torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1, 887 torch.nn.functional.gelu: lambda input, approximate="none": -1, 888 torch.nn.functional.glu: lambda input, dim=-1: -1, 889 torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1, # noqa: B950 890 torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1, 891 torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1, 892 torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1, 893 
torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1, 894 torch.nn.functional.hinge_embedding_loss: ( 895 lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1 896 ), 897 torch.nn.functional.instance_norm: ( 898 lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1 # noqa: B950 899 ), 900 torch.nn.functional.interpolate: ( 901 lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1 # noqa: B950 902 ), 903 torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, # noqa: B950 904 torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, 905 torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, 906 torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1, 907 torch.nn.functional.linear: lambda input, weight, bias=None: -1, 908 torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1, 909 torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1, 910 torch.nn.functional.logsigmoid: lambda input: -1, 911 torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1, 912 torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1, 913 torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1, 914 torch.nn.functional.margin_ranking_loss: ( 915 lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1 916 ), 917 torch.nn.functional.max_pool1d: ( 918 lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1 
919 ), 920 torch.nn.functional.max_pool1d_with_indices: ( 921 lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1 922 ), 923 torch.nn.functional.max_pool2d: ( 924 lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1 925 ), 926 torch.nn.functional.max_pool2d_with_indices: ( 927 lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1 928 ), 929 torch.nn.functional.max_pool3d: ( 930 lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1 931 ), 932 torch.nn.functional.max_pool3d_with_indices: ( 933 lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1 934 ), 935 torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950 936 torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950 937 torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950 938 torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, 939 torch.nn.functional.multi_head_attention_forward: ( 940 lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1 # noqa: B950 941 ), 942 torch.nn.functional.multi_margin_loss: ( 943 lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1 944 ), 945 
torch.nn.functional.multilabel_margin_loss: ( 946 lambda input, target, size_average=None, reduce=None, reduction="mean": -1 947 ), 948 torch.nn.functional.multilabel_soft_margin_loss: ( 949 lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1 950 ), 951 torch.nn.functional.nll_loss: ( 952 lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1 953 ), 954 torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1, 955 torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1, 956 torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1, 957 torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1, 958 torch.nn.functional.poisson_nll_loss: ( 959 lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1 # noqa: B950 960 ), 961 torch.nn.functional.prelu: lambda input, weight: -1, 962 torch.nn.functional.relu: lambda input, inplace=False: -1, 963 torch.nn.functional.relu6: lambda input, inplace=False: -1, 964 torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1, 965 torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, # noqa: B950 966 torch.nn.functional.selu: lambda input, inplace=False: -1, 967 torch.nn.functional.silu: lambda input, inplace=False: -1, 968 torch.nn.functional.mish: lambda input, inplace=False: -1, 969 torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1, 970 torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1, # noqa: B950 971 torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0: -1, 972 torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, 
reduction="mean": -1, # noqa: B950 973 torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1, 974 torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1, 975 torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1, 976 torch.nn.functional.softshrink: lambda input, lambd=0.5: -1, 977 torch.nn.functional.softsign: lambda input: -1, 978 torch.nn.functional.tanhshrink: lambda input: -1, 979 torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1, 980 torch.nn.functional.triplet_margin_loss: ( 981 lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950 982 ), 983 torch.nn.functional.triplet_margin_with_distance_loss: ( 984 lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1 985 ), 986 torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1, 987 torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1, 988 torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1, 989 torch.nn.init.constant_: lambda tensor, val: -1, 990 torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1, # noqa: B950 991 torch.nonzero: lambda input, as_tuple=False: -1, 992 torch.nonzero_static: lambda input, *, size, fill_value=-1: -1, 993 torch.argwhere: lambda input: -1, 994 torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1, 995 torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1, 996 torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1, 997 torch.linalg.matrix_norm: lambda input, ord="fro", dim=( 998 -2, 999 -1, 1000 ), keepdim=False, out=None, dtype=None: -1, 1001 torch.norm_except_dim: lambda v, pow=2, dim=0: -1, 1002 
torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1, 1003 torch.numel: lambda input: -1, 1004 torch.orgqr: lambda input, tau: -1, 1005 torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1, 1006 torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1, 1007 torch.permute: lambda self, dim: -1, 1008 torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1, 1009 torch.pdist: lambda input, p=2: -1, 1010 torch.pinverse: lambda input, rcond=1e-15: -1, 1011 torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1, 1012 torch.pixel_shuffle: lambda input, upscale_factor: -1, 1013 torch.pixel_unshuffle: lambda input, downscale_factor: -1, 1014 torch.poisson: lambda input, generator=None: -1, 1015 torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1, 1016 torch.polygamma: lambda input, n, out=None: -1, 1017 torch.positive: lambda input, out=None: -1, 1018 torch.prelu: lambda input, weight: -1, 1019 torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, 1020 torch.pow: lambda input, exponent, out=None: -1, 1021 torch.prod: lambda input, dtype=None: -1, 1022 torch.put: lambda input, index, source, accumulate=False: -1, 1023 torch.q_per_channel_axis: lambda input: -1, 1024 torch.q_per_channel_scales: lambda input: -1, 1025 torch.q_per_channel_zero_points: lambda input: -1, 1026 torch.q_scale: lambda input: -1, 1027 torch.q_zero_point: lambda input: -1, 1028 torch.qr: lambda input, some=True, out=None: -1, 1029 torch.linalg.qr: lambda input, mode="reduced", out=None: -1, 1030 torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1, 1031 torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1, 1032 torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1, 1033 torch.quantize_per_tensor: lambda input, scale, 
zero_point, dtype: -1, 1034 torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1, 1035 torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1, 1036 torch.quantized_gru_cell: ( 1037 lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 1038 ), 1039 torch.quantized_lstm_cell: ( 1040 lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 1041 ), 1042 torch.quantized_max_pool1d: ( 1043 lambda input, kernel_size, stride=(), padding=(0,), dilation=( 1044 1, 1045 ), ceil_mode=False: -1 1046 ), 1047 torch.quantized_max_pool2d: ( 1048 lambda input, kernel_size, stride=(), padding=(0, 0), dilation=( 1049 1, 1050 1, 1051 ), ceil_mode=False: -1 1052 ), 1053 torch.quantized_max_pool3d: ( 1054 lambda input, kernel_size, stride=(), padding=(0, 0, 0), dilation=( 1055 1, 1056 1, 1057 1, 1058 ), ceil_mode=False: -1 1059 ), 1060 torch.quantized_rnn_relu_cell: ( 1061 lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 1062 ), 1063 torch.quantized_rnn_tanh_cell: ( 1064 lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 1065 ), 1066 torch.rad2deg: lambda input, out=None: -1, 1067 torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, 1068 torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, 1069 torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, 1070 torch.ravel: lambda input: -1, 1071 torch.real: lambda input, out=None: -1, 1072 torch.vdot: lambda 
input, other, out=None: -1, 1073 torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1, 1074 torch.view_as_real: lambda input: -1, 1075 torch.view_as_complex: lambda input: -1, 1076 torch.reciprocal: lambda input, out=None: -1, 1077 torch.relu: lambda input, inplace=False: -1, 1078 torch.remainder: lambda input, other, out=None: -1, 1079 torch.renorm: lambda input, p, dim, maxnorm, out=None: -1, 1080 torch.repeat_interleave: lambda input, dim=None: -1, 1081 torch.reshape: lambda input, shape: -1, 1082 torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1, 1083 torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950 1084 torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, 1085 torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950 1086 torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, 1087 torch.roll: lambda input, shifts, dims=None: -1, 1088 torch.rot90: lambda input, k=1, dims=(0, 1): -1, 1089 torch.round: lambda input, out=None: -1, 1090 torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack 1091 torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1), 1092 torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1, 1093 torch.rsqrt: lambda input, out=None: -1, 1094 torch.rsub: lambda input, other, alpha=1: -1, 1095 torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1, 1096 torch.scatter: lambda input, dim, index, src: -1, 1097 torch.scatter_add: lambda input, dim, index, src: -1, 1098 torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1, 1099 torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1, 1100 torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, 
axis=0, unsafe=False: -1, # noqa: B950 1101 torch.select: lambda input, dim, index: -1, 1102 torch.select_scatter: lambda input, src, dim, index: -1, 1103 torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1, 1104 torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1, 1105 torch.selu: lambda input, inplace=False: -1, 1106 torch.sigmoid: lambda input, out=None: -1, 1107 torch.sign: lambda input, out=None: -1, 1108 torch.signbit: lambda input, out=None: -1, 1109 torch.sgn: lambda input, out=None: -1, 1110 torch.sin: lambda input, out=None: -1, 1111 torch.sinc: lambda input, out=None: -1, 1112 torch.sinh: lambda input, out=None: -1, 1113 torch.slogdet: lambda input: -1, 1114 torch.linalg.slogdet: lambda input: -1, 1115 torch.smm: lambda input, mat2: -1, 1116 torch.spmm: lambda input, mat2: -1, 1117 torch.softmax: lambda input, dim, dtype=None: -1, 1118 torch.linalg.solve: lambda A, B, left=True, out=None: -1, 1119 torch.linalg.solve_ex: lambda A, B, left=True, check_errors=False, out=None: -1, 1120 torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1, 1121 torch.split: lambda tensor, split_size_or_sections, dim=0: -1, 1122 torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1, 1123 torch.sqrt: lambda input, out=None: -1, 1124 torch.square: lambda input, out=None: -1, 1125 torch.squeeze: lambda input, dim=None, out=None: -1, 1126 torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1, 1127 torch.stack: lambda tensors, dim=0, out=None: -1, 1128 torch.std: lambda input, dim=None: -1, 1129 torch.std_mean: lambda input, dim=None: -1, 1130 torch.stft: ( 1131 lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1 # noqa: B950 1132 ), 1133 torch.sub: lambda input, other, out=None: -1, 1134 torch.subtract: lambda input, other, out=None: -1, 1135 torch.sum: 
lambda input, dim=None: -1, 1136 torch.sym_float: lambda input: -1, 1137 torch.sym_int: lambda input: -1, 1138 torch.sym_max: lambda a, b: -1, 1139 torch.sym_min: lambda a, b: -1, 1140 torch.sym_not: lambda input: -1, 1141 torch.sym_ite: lambda a, b, c: -1, 1142 torch._sym_sqrt: lambda input: -1, 1143 torch._sym_cos: lambda input: -1, 1144 torch._sym_cosh: lambda input: -1, 1145 torch._sym_sin: lambda input: -1, 1146 torch._sym_sinh: lambda input: -1, 1147 torch._sym_tan: lambda input: -1, 1148 torch._sym_tanh: lambda input: -1, 1149 torch._sym_asin: lambda input: -1, 1150 torch._sym_acos: lambda input: -1, 1151 torch._sym_atan: lambda input: -1, 1152 torch.nansum: lambda input, dim=None: -1, 1153 torch.svd: lambda input, some=True, compute_uv=True, out=None: -1, 1154 torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1, 1155 torch.linalg.svd: lambda input, full_matrices=True, out=None: -1, 1156 torch.linalg.svdvals: lambda input, out=None: -1, 1157 torch.swapaxes: lambda input, dim0, dim1: -1, 1158 torch.swapdims: lambda input, axis0, axis1: -1, 1159 torch.special.airy_ai: lambda input: -1, 1160 torch.special.bessel_j0: lambda input: -1, 1161 torch.special.bessel_j1: lambda input: -1, 1162 torch.special.bessel_y0: lambda input: -1, 1163 torch.special.bessel_y1: lambda input: -1, 1164 torch.special.chebyshev_polynomial_t: lambda input, n, out=None: -1, 1165 torch.special.chebyshev_polynomial_u: lambda input, n, out=None: -1, 1166 torch.special.chebyshev_polynomial_v: lambda input, n, out=None: -1, 1167 torch.special.chebyshev_polynomial_w: lambda input, n, out=None: -1, 1168 torch.special.digamma: lambda input: -1, 1169 torch.special.entr: lambda input: -1, 1170 torch.special.erf: lambda input: -1, 1171 torch.special.erfc: lambda input: -1, 1172 torch.special.erfcx: lambda input: -1, 1173 torch.special.erfinv: lambda input: -1, 1174 torch.special.exp2: lambda input: -1, 1175 torch.special.expit: lambda input: -1, 1176 torch.special.expm1: lambda input: -1, 
1177 torch.special.gammainc: lambda input, other, out=None: -1, 1178 torch.special.gammaincc: lambda input, other, out=None: -1, 1179 torch.special.gammaln: lambda input: -1, 1180 torch.special.hermite_polynomial_h: lambda input, n, out=None: -1, 1181 torch.special.hermite_polynomial_he: lambda input, n, out=None: -1, 1182 torch.special.i0: lambda input: -1, 1183 torch.special.i0e: lambda input: -1, 1184 torch.special.i1: lambda input: -1, 1185 torch.special.i1e: lambda input: -1, 1186 torch.special.laguerre_polynomial_l: lambda input, n, out=None: -1, 1187 torch.special.legendre_polynomial_p: lambda input, n, out=None: -1, 1188 torch.special.log1p: lambda input: -1, 1189 torch.special.log_ndtr: lambda input: -1, 1190 torch.special.log_softmax: lambda input, dim, dtype=None: -1, 1191 torch.special.logit: lambda input: -1, 1192 torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1, 1193 torch.special.modified_bessel_i0: lambda input: -1, 1194 torch.special.modified_bessel_i1: lambda input: -1, 1195 torch.special.modified_bessel_k0: lambda input: -1, 1196 torch.special.modified_bessel_k1: lambda input: -1, 1197 torch.special.multigammaln: lambda input, p: -1, 1198 torch.special.ndtr: lambda input: -1, 1199 torch.special.ndtri: lambda input: -1, 1200 torch.special.polygamma: lambda input, n, out=None: -1, 1201 torch.special.psi: lambda input: -1, 1202 torch.special.round: lambda input: -1, 1203 torch.special.scaled_modified_bessel_k0: lambda input: -1, 1204 torch.special.scaled_modified_bessel_k1: lambda input: -1, 1205 torch.special.shifted_chebyshev_polynomial_t: lambda input, n, out=None: -1, 1206 torch.special.shifted_chebyshev_polynomial_u: lambda input, n, out=None: -1, 1207 torch.special.shifted_chebyshev_polynomial_v: lambda input, n, out=None: -1, 1208 torch.special.shifted_chebyshev_polynomial_w: lambda input, n, out=None: -1, 1209 torch.special.sinc: lambda input: -1, 1210 torch.special.softmax: lambda input, dim, dtype=None: -1, 1211 
torch.special.spherical_bessel_j0: lambda input: -1, 1212 torch.special.xlog1py: lambda input, other, out=None: -1, 1213 torch.special.xlogy: lambda input, other, out=None: -1, 1214 torch.special.zeta: lambda self, other, out=None: -1, 1215 torch.t: lambda input: -1, 1216 torch.take: lambda input, index: -1, 1217 torch.take_along_dim: lambda input, indices, dim=None, out=None: -1, 1218 torch.tan: lambda input, out=None: -1, 1219 torch.tanh: lambda input, out=None: -1, 1220 torch.linalg.tensorinv: lambda a, ind=2: -1, 1221 torch.linalg.tensorsolve: lambda a, b, dims=None: -1, 1222 torch.tensordot: lambda a, b, dims=2, out=None: -1, 1223 torch.tensor_split: lambda input, indices_or_sections, dim=0: -1, 1224 torch.threshold: lambda input, threshold, value, inplace=False: -1, 1225 torch.tile: lambda input, dims: -1, 1226 torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1, 1227 torch.trace: lambda input: -1, 1228 torch.transpose: lambda input, dim0, dim1: -1, 1229 torch.trapz: lambda y, x=None, dim=-1: -1, 1230 torch.trapezoid: lambda y, x=None, dim=-1: -1, 1231 torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1, 1232 torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1, 1233 torch.tril: lambda input, diagonal=0, out=None: -1, 1234 torch.triplet_margin_loss: ( 1235 lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950 1236 ), 1237 torch.triu: lambda input, diagonal=0, out=None: -1, 1238 torch.true_divide: lambda input, other: -1, 1239 torch.trunc: lambda input, out=None: -1, 1240 torch.unbind: lambda input, dim=0: -1, 1241 torch.unflatten: lambda input, dim, sizes, names: -1, 1242 torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1, 1243 torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1, 1244 
torch.unravel_index: lambda indices, shape: -1, 1245 torch.unsafe_chunk: lambda input, chunks, dim=0: -1, 1246 torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1, 1247 torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1, 1248 torch.unsqueeze: lambda input, dim, out=None: -1, 1249 torch.linalg.vander: lambda x, N=None: -1, 1250 torch.var: lambda input, dim=None: -1, 1251 torch.var_mean: lambda input, dim=None: -1, 1252 torch.vsplit: lambda input, indices_or_sections: -1, 1253 torch.vstack: lambda tensors, out=None: -1, 1254 torch.where: lambda condition, x=None, y=None: -1, 1255 torch._wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1, 1256 torch._wrapped_quantized_linear_prepacked: ( 1257 lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel : -1 # noqa: B950 1258 ), 1259 torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, 1260 torch._fw_primal_copy: lambda self, level: -1, 1261 torch._make_dual_copy: lambda primal, tangent, level: -1, 1262 torch.view_as_real_copy: lambda self: -1, 1263 torch.view_as_complex_copy: lambda self: -1, 1264 torch._conj_copy: lambda self: -1, 1265 torch._neg_view_copy: lambda self: -1, 1266 torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1, 1267 torch._sparse_broadcast_to_copy: lambda self, size: -1, 1268 torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1, 1269 torch.expand_copy: lambda self, size, *, implicit=False: -1, 1270 torch.narrow_copy: lambda self, dim, start, length: -1, 1271 torch.permute_copy: lambda self, dims: -1, 1272 torch._reshape_alias_copy: lambda self, size, stride: -1, 1273 torch.select_copy: lambda self, dim, index: -1, 1274 torch.detach_copy: lambda self: -1, 1275 torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1, 1276 torch.split_copy: lambda self, split_size, dim=0: -1, 1277 
torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1, 1278 torch.squeeze_copy: lambda self, dim: -1, 1279 torch.t_copy: lambda self: -1, 1280 torch.transpose_copy: lambda self, dim0, dim1: -1, 1281 torch.unsqueeze_copy: lambda self, dim: -1, 1282 torch._indices_copy: lambda self: -1, 1283 torch._values_copy: lambda self: -1, 1284 torch.indices_copy: lambda self: -1, 1285 torch.values_copy: lambda self: -1, 1286 torch.crow_indices_copy: lambda self: -1, 1287 torch.col_indices_copy: lambda self: -1, 1288 torch.ccol_indices_copy: lambda self: -1, 1289 torch.row_indices_copy: lambda self: -1, 1290 torch.unbind_copy: lambda self, dim=0: -1, 1291 torch.view_copy: lambda self, dtype: -1, 1292 torch.unfold_copy: lambda self, dimension, size, step: -1, 1293 torch.alias_copy: lambda self: -1, 1294 Tensor.__floordiv__: lambda self, other: -1, 1295 Tensor.__rfloordiv__: lambda self, other: -1, 1296 Tensor.__ifloordiv__: lambda self, other: -1, 1297 Tensor.__truediv__: lambda self, other: -1, 1298 Tensor.__rtruediv__: lambda self, other: -1, 1299 Tensor.__itruediv__: lambda self, other: -1, 1300 Tensor.__lshift__: lambda self, other: -1, 1301 Tensor.__rlshift__: lambda self, other: -1, 1302 Tensor.__ilshift__: lambda self, other: -1, 1303 Tensor.__rshift__: lambda self, other: -1, 1304 Tensor.__rrshift__: lambda self, other: -1, 1305 Tensor.__irshift__: lambda self, other: -1, 1306 Tensor.__and__: lambda self, other: -1, 1307 Tensor.__or__: lambda self, other: -1, 1308 Tensor.__xor__: lambda self, other: -1, 1309 Tensor.__float__: lambda self: -1, 1310 Tensor.__complex__: lambda self: -1, 1311 Tensor.__array__: lambda self, dtype: -1, 1312 Tensor.__bool__: lambda self: -1, 1313 Tensor.__contains__: lambda self, other: -1, 1314 Tensor.__neg__: lambda self: -1, 1315 Tensor.__invert__: lambda self: -1, 1316 Tensor.__mod__: lambda self, other: -1, 1317 Tensor.__rmod__: lambda self, other: -1, 1318 Tensor.__imod__: lambda self, other: -1, 1319 Tensor.__array_wrap__: 
lambda self, array: -1, 1320 Tensor.__getitem__: lambda self, idx: -1, 1321 Tensor.__deepcopy__: lambda self, memo: -1, 1322 Tensor.__int__: lambda self: -1, 1323 Tensor.__long__: lambda self: -1, 1324 Tensor.__index__: lambda self: -1, 1325 Tensor.__len__: lambda self: -1, 1326 Tensor.__format__: lambda self, format_spec: -1, 1327 Tensor.__reduce_ex__: lambda self, proto: -1, 1328 Tensor.__reversed__: lambda self: -1, 1329 Tensor.__repr__: lambda self, *, tensor_contents=None: -1, 1330 Tensor.__setitem__: lambda self, k, v: -1, 1331 Tensor.__setstate__: lambda self, d: -1, 1332 Tensor.T.__get__: lambda self: -1, 1333 Tensor.H.__get__: lambda self: -1, 1334 Tensor.mT.__get__: lambda self: -1, 1335 Tensor.mH.__get__: lambda self: -1, 1336 Tensor._backward_hooks.__get__: lambda self: -1, 1337 Tensor._post_accumulate_grad_hooks.__get__: lambda self: -1, 1338 Tensor._base.__get__: lambda self: -1, 1339 Tensor._cdata.__get__: lambda self: -1, 1340 Tensor.grad.__get__: lambda self: -1, 1341 Tensor._grad.__get__: lambda self: -1, 1342 Tensor._grad_fn.__get__: lambda self: -1, 1343 Tensor.grad_fn.__get__: lambda self: -1, 1344 Tensor._version.__get__: lambda self: -1, 1345 Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1, 1346 Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1, 1347 Tensor.data.__get__: lambda self: -1, 1348 Tensor.device.__get__: lambda self: -1, 1349 Tensor.dtype.__get__: lambda self: -1, 1350 Tensor.is_cuda.__get__: lambda self: -1, 1351 Tensor.is_cpu.__get__: lambda self: -1, 1352 Tensor.is_xla.__get__: lambda self: -1, 1353 Tensor.is_xpu.__get__: lambda self: -1, 1354 Tensor.is_ipu.__get__: lambda self: -1, 1355 Tensor.is_leaf.__get__: lambda self: -1, 1356 Tensor.retains_grad.__get__: lambda self: -1, 1357 Tensor.is_meta.__get__: lambda self: -1, 1358 Tensor.is_mps.__get__: lambda self: -1, 1359 Tensor.is_mtia.__get__: lambda self: -1, 1360 Tensor.is_nested.__get__: 
lambda self: -1, 1361 Tensor.is_maia.__get__: lambda self: -1, 1362 Tensor.is_mkldnn.__get__: lambda self: -1, 1363 Tensor.is_quantized.__get__: lambda self: -1, 1364 Tensor.is_sparse.__get__: lambda self: -1, 1365 Tensor.is_sparse_csr.__get__: lambda self: -1, 1366 Tensor.is_vulkan.__get__: lambda self: -1, 1367 Tensor.itemsize.__get__: lambda self: -1, 1368 Tensor.layout.__get__: lambda self: -1, 1369 Tensor.name.__get__: lambda self: -1, 1370 Tensor.names.__get__: lambda self: -1, 1371 Tensor.nbytes.__get__: lambda self: -1, 1372 Tensor.ndim.__get__: lambda self: -1, 1373 Tensor.output_nr.__get__: lambda self: -1, 1374 Tensor.requires_grad.__get__: lambda self: -1, 1375 Tensor.shape.__get__: lambda self: -1, 1376 Tensor.volatile.__get__: lambda self: -1, 1377 Tensor.real.__get__: lambda self: -1, 1378 Tensor.imag.__get__: lambda self: -1, 1379 Tensor.__cuda_array_interface__.__get__: lambda self: -1, 1380 Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1, 1381 Tensor._dimI: lambda self: -1, 1382 Tensor._dimV: lambda self: -1, 1383 Tensor._indices: lambda self: -1, 1384 Tensor._is_view: lambda self: -1, 1385 Tensor._nnz: lambda self: -1, 1386 Tensor.crow_indices: lambda self: -1, 1387 Tensor.col_indices: lambda self: -1, 1388 Tensor.ccol_indices: lambda self: -1, 1389 Tensor.row_indices: lambda self: -1, 1390 Tensor._update_names: lambda self, names, inplace: -1, 1391 Tensor._values: lambda self: -1, 1392 Tensor.adjoint: lambda self: -1, 1393 Tensor.align_as: lambda self, other: -1, 1394 Tensor.align_to: lambda self, order, ellipsis_idx: -1, 1395 Tensor.apply_: lambda self, callable: -1, 1396 Tensor.as_strided: lambda self, size, stride: -1, 1397 Tensor.as_strided_: lambda self, size, stride: -1, 1398 Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1, 1399 Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1, 1400 Tensor.bool: lambda self, memory_format=torch.preserve_format: 
-1, 1401 Tensor.byte: lambda self, memory_format=torch.preserve_format: -1, 1402 Tensor.char: lambda self, memory_format=torch.preserve_format: -1, 1403 Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1, 1404 Tensor.coalesce: lambda self: -1, 1405 Tensor._coalesced_: lambda self, coalesced: -1, 1406 Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1, 1407 Tensor.copy_: lambda self, src, non_blocking=False: -1, 1408 Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1, 1409 Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1, 1410 Tensor.mtia: lambda self, memory_format=torch.preserve_format: -1, 1411 Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1, 1412 Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1, 1413 Tensor.data_ptr: lambda self: -1, 1414 Tensor.dense_dim: lambda self: -1, 1415 Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1, 1416 Tensor.dim: lambda self: -1, 1417 Tensor.dim_order: lambda self: -1, 1418 Tensor.double: lambda self, memory_format=torch.preserve_format: -1, 1419 Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1, 1420 Tensor.element_size: lambda self: -1, 1421 Tensor.expand: lambda self, size: -1, 1422 Tensor.expand_as: lambda self, other: -1, 1423 Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1, 1424 Tensor.fill_: lambda self, value: -1, 1425 Tensor.fill_diagonal_: lambda self, value: -1, 1426 Tensor.float: lambda self, memory_format=torch.preserve_format: -1, 1427 Tensor.cfloat: lambda self, memory_format=torch.preserve_format: -1, 1428 Tensor.geometric_: lambda self, p, *, generator=None: -1, 1429 Tensor.get_device: lambda self: -1, 1430 Tensor.half: lambda self, memory_format=torch.preserve_format: -1, 1431 Tensor.chalf: lambda self, memory_format=torch.preserve_format: -1, 1432 Tensor.has_names: lambda self: -1, 1433 Tensor.indices: lambda self: -1, 1434 Tensor.int: lambda self, 
memory_format=torch.preserve_format: -1, 1435 Tensor.is_coalesced: lambda self: -1, 1436 Tensor.is_contiguous: lambda self: -1, 1437 Tensor.is_inference: lambda self: -1, 1438 Tensor.is_pinned: lambda self: -1, 1439 Tensor.is_set_to: lambda self, tensor: -1, 1440 Tensor.is_shared: lambda self: -1, 1441 Tensor.item: lambda self: -1, 1442 Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1, 1443 Tensor.log_softmax: lambda self, dim: -1, 1444 Tensor.long: lambda self, memory_format=torch.preserve_format: -1, 1445 Tensor.map_: lambda self, tensor, callable: -1, 1446 Tensor.map2_: lambda self, x, y, callable: -1, 1447 Tensor.mm: lambda self, mat2: -1, 1448 Tensor.module_load: lambda self, other, assign=False: -1, 1449 Tensor.narrow_copy: lambda self, dimension, start, length: -1, 1450 Tensor.ndimension: lambda self: -1, 1451 Tensor.nelement: lambda self: -1, 1452 Tensor._nested_tensor_size: lambda self: -1, 1453 Tensor._nested_tensor_storage_offsets: lambda self: -1, 1454 Tensor._nested_tensor_strides: lambda self: -1, 1455 Tensor.normal_: lambda self: -1, 1456 Tensor.numpy: lambda self: -1, 1457 Tensor.permute: lambda self, dim: -1, 1458 Tensor.pin_memory: lambda self: -1, 1459 Tensor.put_: lambda self, indices, tensor, accumulate=False: -1, 1460 Tensor.qscheme: lambda self: -1, 1461 Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1, 1462 Tensor.record_stream: lambda self, stream: -1, 1463 Tensor.refine_names: lambda self, names: -1, 1464 Tensor.register_hook: lambda self, hook: -1, 1465 Tensor.register_post_accumulate_grad_hook: lambda self, hook: -1, 1466 Tensor.rename: lambda self, name: -1, 1467 Tensor.repeat: lambda self, *size: -1, 1468 Tensor.requires_grad_: lambda self, requires_grad=True: -1, 1469 Tensor.reshape_as: lambda self, other: -1, 1470 Tensor.resize: lambda self, *size: -1, 1471 Tensor.resize_: lambda self, size: -1, 1472 Tensor.resize_as: lambda self, other: -1, 1473 Tensor.resize_as_sparse_: lambda self, other: 
-1, 1474 Tensor.retain_grad: lambda self: -1, 1475 Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1, 1476 Tensor.select_scatter: lambda self, src, dim, index: -1, 1477 Tensor.share_memory_: lambda self: -1, 1478 Tensor.short: lambda self, memory_format=torch.preserve_format: -1, 1479 Tensor.size: lambda self: -1, 1480 Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1, 1481 Tensor.sparse_dim: lambda self: -1, 1482 Tensor.sparse_mask: lambda self, mask: -1, 1483 Tensor._sparse_mask_projection: lambda self, mask, accumulate_matches=False: -1, 1484 Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1, 1485 Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1, 1486 Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1, 1487 Tensor.storage: lambda self: -1, 1488 Tensor.untyped_storage: lambda self: -1, 1489 Tensor.storage_offset: lambda self: -1, 1490 Tensor.storage_type: lambda self: -1, 1491 Tensor.sum_to_size: lambda self, size: -1, 1492 Tensor.tile: lambda self, *reps: -1, 1493 Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1, 1494 Tensor.to_dense: lambda self, dtype=None, *, masked_grad=None: -1, 1495 Tensor._to_dense: lambda self, dtype=None, masked_grad=None: -1, 1496 Tensor.to_sparse: lambda self: -1, 1497 Tensor.tolist: lambda self: -1, 1498 Tensor.to_mkldnn: lambda self: -1, 1499 Tensor.type_as: lambda self, other: -1, 1500 Tensor.unfold: lambda self, dimension, size, step: -1, 1501 Tensor.uniform_: lambda self, from_=0, to=1: -1, 1502 Tensor.values: lambda self: -1, 1503 Tensor.view: lambda self, shape: -1, 1504 Tensor.view_as: lambda self, other: -1, 1505 Tensor.zero_: lambda self: -1, 1506 Tensor.__dlpack__: lambda self, stream=None: -1, 1507 Tensor.__dlpack_device__: lambda self: -1, 1508 torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1, 1509 } # fmt: skip 1510 1511 
privateuse1_backend_name = ( 1512 torch.utils.backend_registration._privateuse1_backend_name 1513 ) 1514 if hasattr(Tensor, privateuse1_backend_name): 1515 ret[getattr(Tensor, privateuse1_backend_name)] = ( 1516 lambda self, device=None, non_blocking=False, **kwargs: -1 1517 ) 1518 ret[getattr(Tensor, f"is_{privateuse1_backend_name}").__get__] = lambda self: -1 1519 1520 ret2 = {} 1521 ignored = get_ignored_functions() 1522 1523 for k, v in ret.items(): 1524 # Generate methods like __add__ and add_ by default from add 1525 names = [ 1526 k.__name__, # Default method 1527 k.__name__ + "_", # Inplace variant 1528 "__" + k.__name__ + "__", # Dunder method 1529 "__i" + k.__name__ + "__", # Inplace dunder method 1530 "__r" + k.__name__ + "__", # Reverse dunder method 1531 ] 1532 1533 if k.__name__.startswith("bitwise_"): 1534 # bitwise_<op> have dunder methods of the form __<op>__ 1535 # And so on. 1536 subname = k.__name__[len("bitwise_") :] 1537 names.extend( 1538 ["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"] 1539 ) 1540 1541 for name in names: 1542 func = getattr(Tensor, name, None) 1543 if callable(func) and func not in ret and func not in ignored: 1544 ret2[func] = v 1545 1546 ret.update(ret2) 1547 return ret 1548 1549 1550def wrap_torch_function(dispatcher: Callable): 1551 """Wraps a given function with ``__torch_function__`` -related functionality. 1552 1553 Parameters 1554 ---------- 1555 dispatcher: Callable 1556 A callable that returns an iterable of Tensor-likes passed into the function. 1557 1558 Note 1559 ---- 1560 This decorator may reduce the performance of your code. Generally, it's enough to express 1561 your code as a series of functions that, themselves, support __torch_function__. If you 1562 find yourself in the rare situation where this is not the case, e.g. if you're wrapping a 1563 low-level library and you also need it to work for Tensor-likes, then this function is available. 
1564 1565 Examples 1566 -------- 1567 >>> def dispatcher(a): # Must have the same signature as func 1568 ... return (a,) 1569 >>> @torch.overrides.wrap_torch_function(dispatcher) 1570 >>> def func(a): # This will make func dispatchable by __torch_function__ 1571 ... return a + 0 1572 """ 1573 1574 def inner(func): 1575 @functools.wraps(func) 1576 def wrapped(*args, **kwargs): 1577 relevant_args = dispatcher(*args, **kwargs) 1578 if has_torch_function(relevant_args): 1579 return handle_torch_function(wrapped, relevant_args, *args, **kwargs) 1580 1581 return func(*args, **kwargs) 1582 1583 return wrapped 1584 1585 return inner 1586 1587 1588def _get_overloaded_args( 1589 relevant_args: Iterable[Any], 1590 get_type_fn: Callable[[Any], Type] = None, 1591) -> List[Any]: 1592 """Returns a list of arguments on which to call __torch_function__. 1593 1594 Checks arguments in relevant_args for __torch_function__ implementations, 1595 storing references to the arguments and their types in overloaded_args and 1596 overloaded_types in order of calling precedence. Only distinct types are 1597 considered. If a type is a subclass of another type it will have higher 1598 precedence, otherwise the precedence order is the same as the order of 1599 arguments in relevant_args, that is, from left-to-right in the argument list. 1600 1601 The precedence-determining algorithm implemented in this function is 1602 described in `NEP-0018`_. 1603 1604 See torch::append_overloaded_arg for the equivalent function in the C++ 1605 implementation. 1606 1607 Parameters 1608 ---------- 1609 relevant_args : iterable of array-like 1610 Iterable of array-like arguments to check for __torch_function__ 1611 methods. 1612 1613 get_type_fn : callable, optional 1614 Function to call on each argument in relevant_args to get its type. 
1615 1616 Returns 1617 ------- 1618 overloaded_args : list 1619 Arguments from relevant_args on which to call __torch_function__ 1620 methods, in the order in which they should be called. 1621 1622 .. _NEP-0018: 1623 https://numpy.org/neps/nep-0018-array-function-protocol.html 1624 """ 1625 if get_type_fn is None: 1626 get_type_fn = type 1627 1628 # If torch function is not enabled, there are no overloaded types 1629 if not torch._C._is_torch_function_enabled(): 1630 return [] 1631 # Runtime is O(num_arguments * num_unique_types) 1632 overloaded_types: Set[Type] = set() 1633 overloaded_args: List[Any] = [] 1634 for arg in relevant_args: 1635 arg_type = get_type_fn(arg) 1636 # We only collect arguments if they have a unique type, which ensures 1637 # reasonable performance even with a long list of possibly overloaded 1638 # arguments. 1639 # 1640 # NB: Important to exclude _disabled_torch_function_impl, otherwise 1641 # https://github.com/pytorch/pytorch/issues/64687 1642 if ( 1643 arg_type not in overloaded_types 1644 and hasattr(arg_type, "__torch_function__") 1645 and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl 1646 ): 1647 # Create lists explicitly for the first type (usually the only one 1648 # done) to avoid setting up the iterator for overloaded_args. 1649 if overloaded_types: 1650 overloaded_types.add(arg_type) 1651 # By default, insert argument at the end, but if it is 1652 # subclass of another argument, insert it before that argument. 1653 # This ensures "subclasses before superclasses". 
1654 index = len(overloaded_args) 1655 for i, old_arg in enumerate(overloaded_args): 1656 if issubclass(arg_type, get_type_fn(old_arg)): 1657 index = i 1658 break 1659 overloaded_args.insert(index, arg) 1660 else: 1661 overloaded_types = {arg_type} 1662 overloaded_args = [arg] 1663 return overloaded_args 1664 1665 1666def handle_torch_function( 1667 public_api: Callable, 1668 relevant_args: Iterable[Any], 1669 *args, 1670 **kwargs, 1671) -> Any: 1672 """Implement a function with checks for ``__torch_function__`` overrides. 1673 1674 See torch::autograd::handle_torch_function for the equivalent of this 1675 function in the C++ implementation. 1676 1677 Arguments 1678 --------- 1679 public_api : function 1680 Function exposed by the public torch API originally called like 1681 ``public_api(*args, **kwargs)`` on which arguments are now being 1682 checked. 1683 relevant_args : iterable 1684 Iterable of arguments to check for __torch_function__ methods. 1685 args : tuple 1686 Arbitrary positional arguments originally passed into ``public_api``. 1687 kwargs : tuple 1688 Arbitrary keyword arguments originally passed into ``public_api``. 1689 1690 Returns 1691 ------- 1692 object 1693 Result from calling ``implementation`` or an ``__torch_function__`` 1694 method, as appropriate. 1695 1696 Raises 1697 ------ 1698 TypeError : if no implementation is found. 1699 1700 Example 1701 ------- 1702 >>> def func(a): 1703 ... if has_torch_function_unary(a): 1704 ... return handle_torch_function(func, (a,), a) 1705 ... return a + 0 1706 """ 1707 # Check for __torch_function__ methods. 1708 overloaded_args = _get_overloaded_args(relevant_args) 1709 # overloaded_args already have unique types. 1710 types = tuple(map(type, overloaded_args)) 1711 1712 # Check for __torch_function__ mode. 
1713 if _is_torch_function_mode_enabled(): 1714 # if we're here, the mode must be set to a TorchFunctionStackMode 1715 # this unsets it and calls directly into TorchFunctionStackMode's torch function 1716 with _pop_mode_temporarily() as mode: 1717 result = mode.__torch_function__(public_api, types, args, kwargs) 1718 if result is not NotImplemented: 1719 return result 1720 1721 # Call overrides 1722 for overloaded_arg in overloaded_args: 1723 # This call needs to become a classmethod call in the future. 1724 # See https://github.com/pytorch/pytorch/issues/63767 1725 torch_func_method = overloaded_arg.__torch_function__ 1726 if ( 1727 hasattr(torch_func_method, "__self__") 1728 and torch_func_method.__self__ is overloaded_arg 1729 and torch_func_method is not torch._C._disabled_torch_function_impl 1730 ): 1731 warnings.warn( 1732 "Defining your `__torch_function__ as a plain method is deprecated and " 1733 "will be an error in future, please define it as a classmethod.", 1734 DeprecationWarning, 1735 ) 1736 1737 # Use `public_api` instead of `implementation` so __torch_function__ 1738 # implementations can do equality/identity comparisons. 1739 result = torch_func_method(public_api, types, args, kwargs) 1740 1741 if result is not NotImplemented: 1742 return result 1743 1744 func_name = f"{public_api.__module__}.{public_api.__name__}" 1745 msg = ( 1746 f"no implementation found for '{func_name}' on types that implement " 1747 f"__torch_function__: {[type(arg) for arg in overloaded_args]}" 1748 ) 1749 if _is_torch_function_mode_enabled(): 1750 msg += f" nor in mode {_get_current_function_mode()}" 1751 raise TypeError(msg) 1752 1753 1754has_torch_function = _add_docstr( 1755 _has_torch_function, 1756 r"""Check for __torch_function__ implementations in the elements of an iterable 1757 or if a __torch_function__ mode is enabled. Considers exact ``Tensor`` s 1758 and ``Parameter`` s non-dispatchable. 
Use this to guard a call to 1759 :func:`handle_torch_function`; don't use it to test if something 1760 is Tensor-like, use :func:`is_tensor_like` instead. 1761 Arguments 1762 --------- 1763 relevant_args : iterable 1764 Iterable or arguments to check for __torch_function__ methods. 1765 Returns 1766 ------- 1767 bool 1768 True if any of the elements of relevant_args have __torch_function__ 1769 implementations, False otherwise. 1770 See Also 1771 ________ 1772 torch.is_tensor_like 1773 Checks if something is a Tensor-like, including an exact ``Tensor``. 1774 """, 1775) 1776 1777has_torch_function_unary = _add_docstr( 1778 _has_torch_function_unary, 1779 r"""Special case of `has_torch_function` for single inputs. 1780 Instead of: 1781 `has_torch_function((t,))` 1782 call: 1783 `has_torch_function_unary(t)` 1784 which skips unnecessary packing and unpacking work. 1785 """, 1786) 1787 1788has_torch_function_variadic = _add_docstr( 1789 _has_torch_function_variadic, 1790 r"""Special case of `has_torch_function` that skips tuple creation. 1791 1792 This uses the METH_FASTCALL protocol introduced in Python 3.7 1793 1794 Instead of: 1795 `has_torch_function((a, b))` 1796 call: 1797 `has_torch_function_variadic(a, b)` 1798 which skips unnecessary packing and unpacking work. 
1799 """, 1800) 1801 1802 1803@functools.lru_cache(None) 1804def _get_overridable_functions() -> ( 1805 Tuple[Dict[Any, List[Callable]], Dict[Callable, str]] 1806): 1807 overridable_funcs = collections.defaultdict(list) 1808 index = {} 1809 tested_namespaces = [ 1810 ("torch", torch, torch.__all__), 1811 ("torch.functional", torch.functional, torch.functional.__all__), 1812 ("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)), 1813 ("torch.nn.init", torch.nn.init, dir(torch.nn.init)), 1814 ("torch.Tensor", torch.Tensor, dir(torch.Tensor)), 1815 ("torch.linalg", torch.linalg, dir(torch.linalg)), 1816 ("torch.fft", torch.fft, dir(torch.fft)), 1817 ("torch.special", torch.special, dir(torch.special)), 1818 ] 1819 for namespace_str, namespace, ns_funcs in tested_namespaces: 1820 for func_name in ns_funcs: 1821 ignore = False 1822 # ignore private functions or functions that are deleted in torch.__init__ 1823 if namespace is not torch.Tensor: 1824 if func_name.startswith("__"): 1825 continue 1826 elif func_name.startswith("_"): 1827 ignore = True 1828 elif func_name.endswith("_"): 1829 ignore = True 1830 elif not func_name[0].islower(): 1831 ignore = True 1832 elif func_name == "unique_dim": 1833 continue 1834 else: 1835 func = getattr(namespace, func_name) 1836 if getattr(object, func_name, None) == func: 1837 continue 1838 if func_name == "__weakref__": 1839 continue 1840 func = getattr(namespace, func_name) 1841 if namespace is torch.Tensor and getattr(object, func_name, None) == func: 1842 continue 1843 # ignore re-exported modules 1844 if isinstance(func, types.ModuleType): 1845 continue 1846 # ignore __future__ imports 1847 if isinstance(func, __future__._Feature): 1848 continue 1849 1850 if not callable(func) and hasattr(func, "__get__"): 1851 index[func.__get__] = f"{namespace_str}.{func_name}.__get__" 1852 index[func.__set__] = f"{namespace_str}.{func_name}.__set__" 1853 if ignore: 1854 continue 1855 if func.__get__ in get_ignored_functions(): 
1856 msg = ( 1857 "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions " 1858 "but still has an explicit override" 1859 ) 1860 assert func.__get__ not in get_testing_overrides(), msg.format( 1861 namespace, func.__name__ 1862 ) 1863 continue 1864 else: 1865 overridable_funcs[func].append(func.__get__) 1866 continue 1867 1868 if not callable(func): 1869 continue 1870 1871 index[func] = f"{namespace_str}.{func_name}" 1872 1873 if ignore: 1874 continue 1875 1876 # cannot be overriden by __torch_function__ 1877 if func in get_ignored_functions(): 1878 msg = ( 1879 "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions " 1880 "but still has an explicit override" 1881 ) 1882 assert func not in get_testing_overrides(), msg.format( 1883 namespace, func.__name__ 1884 ) 1885 continue 1886 overridable_funcs[namespace].append(func) 1887 return overridable_funcs, index 1888 1889 1890@_disable_user_warnings 1891def get_overridable_functions() -> Dict[Any, List[Callable]]: 1892 """List functions that are overridable via __torch_function__ 1893 1894 Returns 1895 ------- 1896 Dict[Any, List[Callable]] 1897 A dictionary that maps namespaces that contain overridable functions 1898 to functions in that namespace that can be overridden. 1899 """ 1900 return _get_overridable_functions()[0] 1901 1902 1903@_disable_user_warnings 1904def resolve_name(f): 1905 """Get a human readable string name for a function passed to 1906 __torch_function__ 1907 1908 Arguments 1909 --------- 1910 f : Callable 1911 Function to resolve the name of. 1912 1913 Returns 1914 ------- 1915 str 1916 Name of the function; if eval'ed it should give back the input 1917 function. 
1918 """ 1919 if isinstance(f, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)): 1920 return str(f) 1921 return _get_overridable_functions()[1].get(f) 1922 1923 1924@functools.lru_cache(None) 1925def _get_tensor_methods() -> Set[Callable]: 1926 """Returns a set of the overridable methods on ``torch.Tensor``""" 1927 overridable_funcs = get_overridable_functions() 1928 methods = set(overridable_funcs[torch.Tensor]) 1929 return methods 1930 1931 1932@_disable_user_warnings 1933def is_tensor_method_or_property(func: Callable) -> bool: 1934 """ 1935 Returns True if the function passed in is a handler for a 1936 method or property belonging to ``torch.Tensor``, as passed 1937 into ``__torch_function__``. 1938 1939 .. note:: 1940 For properties, their ``__get__`` method must be passed in. 1941 1942 This may be needed, in particular, for the following reasons: 1943 1944 1. Methods/properties sometimes don't contain a `__module__` slot. 1945 2. They require that the first passed-in argument is an instance 1946 of ``torch.Tensor``. 1947 1948 Examples 1949 -------- 1950 >>> is_tensor_method_or_property(torch.Tensor.add) 1951 True 1952 >>> is_tensor_method_or_property(torch.add) 1953 False 1954 """ 1955 return func in _get_tensor_methods() or func.__name__ == "__get__" 1956 1957 1958def is_tensor_like(inp): 1959 """ 1960 Returns ``True`` if the passed-in input is a Tensor-like. 1961 1962 Currently, this occurs whenever there's a ``__torch_function__`` 1963 attribute on the type of the input. 1964 1965 Examples 1966 -------- 1967 A subclass of tensor is generally a Tensor-like. 1968 1969 >>> class SubTensor(torch.Tensor): ... 1970 >>> is_tensor_like(SubTensor([0])) 1971 True 1972 1973 Built-in or user types aren't usually Tensor-like. 1974 1975 >>> is_tensor_like(6) 1976 False 1977 >>> is_tensor_like(None) 1978 False 1979 >>> class NotATensor: ... 
1980 >>> is_tensor_like(NotATensor()) 1981 False 1982 1983 But, they can be made Tensor-like by implementing __torch_function__. 1984 1985 >>> class TensorLike: 1986 ... @classmethod 1987 ... def __torch_function__(cls, func, types, args, kwargs): 1988 ... return -1 1989 >>> is_tensor_like(TensorLike()) 1990 True 1991 """ 1992 return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__") 1993 1994 1995class TorchFunctionMode: 1996 """ 1997 A ``TorchFunctionMode`` allows you to override the meaning of all 1998 ``__torch_function__`` overrideable functions within a dynamic scope, 1999 without having to actually create a tensor subclass or manually 2000 monkey-patch functions in the PyTorch API. Some common situations 2001 where you should use a mode: 2002 2003 * You want to override the meaning of factory functions, or other 2004 functions that do not otherwise take a tensor as an argument 2005 (these cannot be overridden with tensor subclasses). 2006 2007 * You want to override the behavior of all functions without needing 2008 to wrap your inputs in tensor subclasses; e.g., if you are just 2009 interested in logging intermediate computations. 2010 2011 * You want to control the order of execution of various tensor 2012 subclasses explicitly, rather than implicitly via the return of 2013 ``NotImplemented``. 2014 2015 Independent subclasses of :class:`TorchFunctionMode` are compositional: 2016 modes can be pushed onto a stack using ``with MyMode():``. 2017 When you call functions in the PyTorch API inside your 2018 ``__torch_function__`` implementation, by default, they will forward on to 2019 the next mode on the mode stack. If you want recursively call back into 2020 your current ``__torch_function__`` implementation, either explicitly 2021 invoke ``self.__torch_function__(...)``, or use the context manager 2022 ``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch 2023 API self-referential (beware of infinite loops, in this case!) 
2024 """ 2025 2026 inner: "TorchFunctionMode" 2027 2028 # Force metaclass to generate constructor at the base of the hierarchy 2029 def __init__(self) -> None: 2030 pass 2031 2032 def __torch_function__(self, func, types, args=(), kwargs=None): 2033 raise NotImplementedError 2034 2035 def __enter__(self): 2036 _push_mode(self) 2037 return self 2038 2039 def __exit__(self, exc_type, exc_val, exc_tb): 2040 _pop_mode() 2041 2042 @classmethod 2043 def push(cls, *args, **kwargs): 2044 warnings.warn( 2045 "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`" 2046 ) 2047 instance = cls(*args, **kwargs) 2048 return instance 2049 2050 2051def _get_current_function_mode(): 2052 stack_len = _len_torch_function_stack() 2053 return _get_function_stack_at(stack_len - 1) if stack_len > 0 else None 2054 2055 2056def _get_current_function_mode_stack(): 2057 stack_len = _len_torch_function_stack() 2058 return [_get_function_stack_at(i) for i in range(stack_len)] 2059 2060 2061def _push_mode(mode): 2062 _push_on_torch_function_stack(mode) 2063 2064 2065def _pop_mode(): 2066 old = _pop_torch_function_stack() 2067 return old 2068 2069 2070@contextlib.contextmanager 2071def _pop_mode_temporarily(): 2072 old = _pop_mode() 2073 try: 2074 yield old 2075 finally: 2076 _push_mode(old) 2077 2078 2079class BaseTorchFunctionMode(TorchFunctionMode): 2080 def __torch_function__(self, func, types, args=(), kwargs=None): 2081 if kwargs is None: 2082 kwargs = {} 2083 return func(*args, **kwargs) 2084 2085 2086@contextlib.contextmanager 2087def enable_reentrant_dispatch(): 2088 # NB: this can't simply be 2089 # `enable_reentrant_dispatch = torch._C._RestorePythonTLSSnapshot` 2090 # because: 2091 # 1. torch._C._RestorePythonTLSSnapshot is unavailable when this file 2092 # initially gets imported. Probably an import order thing. 2093 # 2. enable_reentrant_dispatch is technically public API; assigning 2094 # it the object would change the __module__ to look private. 
2095 with torch._C._RestorePythonTLSSnapshot(): 2096 try: 2097 yield 2098 finally: 2099 pass 2100