xref: /aosp_15_r20/external/pytorch/torch/overrides.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2Python implementation of ``__torch_function__``
3
4While most of the torch API and handling for ``__torch_function__`` happens
5at the C++ level, some of the torch API is written in Python so we need
6python-level handling for ``__torch_function__`` overrides as well. The main
7developer-facing functionality in this file are handle_torch_function and
8has_torch_function. See torch/functional.py and test/test_overrides.py
9for usage examples.
10
11Note
12----
13heavily inspired by NumPy's ``__array_function__`` (see:
14https://github.com/pytorch/pytorch/issues/24015 and
15https://www.numpy.org/neps/nep-0018-array-function-protocol.html
16)
17
18If changing this file in a way that can affect ``__torch_function__`` overhead,
19please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
20instructions in the ``README.md`` in that directory.
21"""
22
23import __future__  # noqa: F404
24
25import collections
26import contextlib
27import functools
28import types
29import warnings
30from functools import wraps
31from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type
32
33import torch
34from torch._C import (
35    _add_docstr,
36    _get_function_stack_at,
37    _has_torch_function,
38    _has_torch_function_unary,
39    _has_torch_function_variadic,
40    _is_torch_function_mode_enabled,
41    _len_torch_function_stack,
42    _pop_torch_function_stack,
43    _push_on_torch_function_stack,
44)
45
46
# Public API of this module: these are the names bound by
# ``from torch.overrides import *``. Helpers defined in this file that are
# not listed here (e.g. underscore-prefixed ones) are not exported that way.
__all__ = [
    "get_ignored_functions",
    "get_overridable_functions",
    "get_testing_overrides",
    "handle_torch_function",
    "has_torch_function",
    "resolve_name",
    "is_tensor_like",
    "is_tensor_method_or_property",
    "wrap_torch_function",
    "enable_reentrant_dispatch",
]
59
60
def _disable_user_warnings(
    func: Callable,
    regex: str = ".*is deprecated, please use.*",
    module: str = "torch",
) -> Callable:
    """
    Build a wrapper around ``func`` that silences selected ``UserWarning``s while ``func`` runs.

    A warning is ignored only when its category is ``UserWarning``, its message matches ``regex``
    and it originates from a module matching ``module``. The filter is installed inside a
    ``warnings.catch_warnings()`` block, so the previous filter state is restored after each call.

    Arguments
    ---------
    func : function
        The callable whose matching warnings should be suppressed.
    regex : str
        Pattern (as accepted by ``re.compile``) tested against the ``UserWarning`` message.
    module : str
        Pattern restricting the suppression to warnings raised from matching modules.

    Returns
    -------
    function
        ``func`` wrapped so that matching warnings are dropped; its metadata
        (name, docstring, ...) is copied onto the wrapper.
    """

    def _silenced_call(*args, **kwargs):
        # The filter only lives for the duration of this single call.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                action="ignore",
                message=regex,
                category=UserWarning,
                module=module,
            )
            return func(*args, **kwargs)

    return functools.wraps(func)(_silenced_call)
94
95
@functools.lru_cache(maxsize=None)
@_disable_user_warnings
def get_ignored_functions() -> Set[Callable]:
    """
    Return public functions that cannot be overridden by ``__torch_function__``.

    The result is computed once and then served from the ``lru_cache``.

    Returns
    -------
    Set[Callable]
        A set of functions that are publicly available in the torch API but
        cannot be overridden with ``__torch_function__``. Mostly this is because
        none of the arguments of these functions are tensors or tensor-likes.

    Examples
    --------
    >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
    True
    >>> torch.add in torch.overrides.get_ignored_functions()
    False
    """
    tensor_cls = torch.Tensor

    # Entries reached from the ``torch`` namespace (including ``torch.export``
    # and ``torch.fft``).
    torch_namespace = (
        torch.typename,
        torch.is_tensor,
        torch.is_storage,
        torch.set_default_tensor_type,
        torch.set_default_device,
        torch.get_default_device,
        torch.set_rng_state,
        torch.get_rng_state,
        torch.manual_seed,
        torch.initial_seed,
        torch.seed,
        torch.save,
        torch.load,
        torch.set_printoptions,
        torch.fork,
        torch.get_default_dtype,
        torch.get_num_interop_threads,
        torch.get_num_threads,
        torch.init_num_threads,
        torch.import_ir_module,
        torch.import_ir_module_from_buffer,
        torch.is_anomaly_enabled,
        torch.is_anomaly_check_nan_enabled,
        torch.is_grad_enabled,
        torch.merge_type_from_type_comment,
        torch.parse_ir,
        torch.parse_schema,
        torch.parse_type_comment,
        torch.set_anomaly_enabled,
        torch.set_flush_denormal,
        torch.set_num_interop_threads,
        torch.set_num_threads,
        torch.wait,
        torch.as_tensor,
        torch.from_numpy,
        torch.get_device,
        torch.tensor,
        torch.default_generator,
        torch.has_cuda,
        torch.has_cudnn,
        torch.has_lapack,
        torch.device,
        torch.dtype,
        torch.finfo,
        torch.has_mkl,
        torch.has_mps,
        torch.has_mkldnn,
        torch.has_openmp,
        torch.iinfo,
        torch.memory_format,
        torch.qscheme,
        torch.set_grad_enabled,
        torch.no_grad,
        torch.enable_grad,
        torch.inference_mode,
        torch.is_inference_mode_enabled,
        torch.layout,
        torch.align_tensors,
        torch.arange,
        torch.as_strided,
        torch.bartlett_window,
        torch.blackman_window,
        torch.broadcast_shapes,
        torch.can_cast,
        torch.compile,
        torch.cudnn_affine_grid_generator,
        torch.cudnn_batch_norm,
        torch.cudnn_convolution,
        torch.cudnn_convolution_transpose,
        torch.cudnn_convolution_relu,
        torch.cudnn_convolution_add_relu,
        torch.cudnn_grid_sampler,
        torch.cudnn_is_acceptable,
        torch.empty,
        torch.empty_permuted,
        torch.empty_strided,
        torch.empty_quantized,
        torch.export.export,
        torch.export.load,
        torch.export.register_dataclass,
        torch.export.save,
        torch.eye,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        torch.from_file,
        torch.full,
        torch.fill,
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        torch.linspace,
        torch.logspace,
        torch.mkldnn_adaptive_avg_pool2d,
        torch.mkldnn_convolution,
        torch.mkldnn_max_pool2d,
        torch.mkldnn_max_pool3d,
        torch.mkldnn_linear_backward_weights,
        torch.mkldnn_rnn_layer,
        torch.normal,
        torch.ones,
        torch.promote_types,
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        torch.range,
        torch.result_type,
        torch.scalar_tensor,
        torch.sparse_coo_tensor,
        torch.sparse_compressed_tensor,
        torch.sparse_csr_tensor,
        torch.sparse_csc_tensor,
        torch.sparse_bsr_tensor,
        torch.sparse_bsc_tensor,
        torch.sym_constrain_range,
        torch.sym_constrain_range_for_size,
        torch.tril_indices,
        torch.triu_indices,
        torch.vander,
        torch.zeros,
    )

    # ``torch.nn`` entries plus the dispatch helpers defined in this module.
    nn_and_dispatch_helpers = (
        torch._jit_internal.boolean_dispatch,
        torch.nn.functional.assert_int_or_pair,
        torch.nn.functional.upsample,
        torch.nn.functional.upsample_bilinear,
        torch.nn.functional.upsample_nearest,
        torch.nn.functional.has_torch_function,
        torch.nn.functional.has_torch_function_unary,
        torch.nn.functional.has_torch_function_variadic,
        torch.nn.functional.handle_torch_function,
        torch.nn.functional.sigmoid,
        torch.nn.functional.hardsigmoid,
        torch.nn.functional.tanh,
        torch.nn.functional._canonical_mask,
        torch.nn.functional._none_or_dtype,
        # Doesn't actually take or return tensor arguments
        torch.nn.init.calculate_gain,
        # These are deprecated; don't test them
        torch.nn.init.uniform,
        torch.nn.init.normal,
        torch.nn.init.constant,
        torch.nn.init.eye,
        torch.nn.init.dirac,
        torch.nn.init.xavier_uniform,
        torch.nn.init.xavier_normal,
        torch.nn.init.kaiming_uniform,
        torch.nn.init.kaiming_normal,
        torch.nn.init.orthogonal,
        torch.nn.init.sparse,
        torch.nested.to_padded_tensor,
        has_torch_function,
        handle_torch_function,
    )

    # Autocast, determinism, precision and other remaining ``torch`` entries.
    global_state_and_misc = (
        torch.set_autocast_enabled,
        torch.is_autocast_enabled,
        torch.set_autocast_dtype,
        torch.get_autocast_dtype,
        torch.clear_autocast_cache,
        torch.set_autocast_cpu_enabled,
        torch.is_autocast_cpu_enabled,
        torch.set_autocast_xla_enabled,
        torch.is_autocast_xla_enabled,
        torch.set_autocast_ipu_enabled,
        torch.is_autocast_ipu_enabled,
        torch.set_autocast_cpu_dtype,
        torch.get_autocast_cpu_dtype,
        torch.set_autocast_ipu_dtype,
        torch.get_autocast_ipu_dtype,
        torch.get_autocast_gpu_dtype,
        torch.set_autocast_gpu_dtype,
        torch.get_autocast_xla_dtype,
        torch.set_autocast_xla_dtype,
        torch.autocast_increment_nesting,
        torch.autocast_decrement_nesting,
        torch.is_autocast_cache_enabled,
        torch.set_autocast_cache_enabled,
        torch.nn.functional.hardswish,
        torch.is_vulkan_available,
        torch.are_deterministic_algorithms_enabled,
        torch.use_deterministic_algorithms,
        torch.is_deterministic_algorithms_warn_only_enabled,
        torch.set_deterministic_debug_mode,
        torch.get_device_module,
        torch.get_deterministic_debug_mode,
        torch.set_float32_matmul_precision,
        torch.get_float32_matmul_precision,
        torch.unify_type_list,
        torch.is_warn_always_enabled,
        torch.set_warn_always,
        torch.vitals_enabled,
        torch.set_vital,
        torch.read_vitals,
        torch.vmap,
        torch.cond,
        torch.frombuffer,
        torch.asarray,
        torch._functional_sym_constrain_range,
        torch._make_dep_token,
    )

    # Attributes looked up on ``torch.Tensor`` itself.
    tensor_members = (
        tensor_cls.__delitem__,
        tensor_cls.__dir__,
        tensor_cls.__getattribute__,
        tensor_cls.__init__,
        tensor_cls.__iter__,
        tensor_cls.__init_subclass__,
        tensor_cls.__delattr__,
        tensor_cls.__setattr__,
        tensor_cls.__torch_function__,
        tensor_cls.__torch_dispatch__,
        tensor_cls.__new__,
        tensor_cls.__class__,
        tensor_cls.__subclasshook__,
        tensor_cls.__hash__,
        tensor_cls.as_subclass,
        tensor_cls.eig,
        tensor_cls.lstsq,
        tensor_cls.reinforce,
        tensor_cls.new,
        tensor_cls.new_tensor,
        tensor_cls.new_empty,
        tensor_cls.new_empty_strided,
        tensor_cls.new_zeros,
        tensor_cls.new_ones,
        tensor_cls.new_full,
        tensor_cls._make_subclass,
        tensor_cls.solve,
        tensor_cls.symeig,
        tensor_cls.stride,
        tensor_cls.unflatten,
        tensor_cls.to_sparse_coo,
        tensor_cls.to_sparse_csr,
        tensor_cls.to_sparse_csc,
        tensor_cls.to_sparse_bsr,
        tensor_cls.to_sparse_bsc,
        tensor_cls._to_sparse,
        tensor_cls._to_sparse_csr,
        tensor_cls._to_sparse_csc,
        tensor_cls._to_sparse_bsr,
        tensor_cls._to_sparse_bsc,
        tensor_cls._typed_storage,
        tensor_cls._reduce_ex_internal,
        tensor_cls._fix_weakref,
        tensor_cls._view_func,
        tensor_cls._view_func_unsafe,
        tensor_cls._rev_view_func_unsafe,
        tensor_cls._make_wrapper_subclass,
        tensor_cls._python_dispatch.__get__,
        tensor_cls._has_symbolic_sizes_strides.__get__,
        tensor_cls._conj,
        tensor_cls._conj_physical,
        tensor_cls._lazy_clone,
        tensor_cls._neg_view,
        tensor_cls._is_zerotensor,
        tensor_cls._is_all_true,
        tensor_cls._is_any_true,
        tensor_cls._addmm_activation,
        tensor_cls.to_padded_tensor,
        tensor_cls._use_count,
    )

    # The groups are merged in their listed order.
    return {
        *torch_namespace,
        *nn_and_dispatch_helpers,
        *global_state_and_misc,
        *tensor_members,
    }
374
375
@functools.lru_cache(maxsize=None)
def get_default_nowrap_functions() -> Set[Callable]:
    """
    Return public functions whose results the default ``Tensor.__torch_function__``
    (the one that preserves subclasses) does not wrap in a subclass.

    These are typically field accesses, i.e. they hand back a Tensor that is
    already stored somewhere on the Tensor, rather than computing a new one.
    Callers expect repeated accesses to yield the very same object
    (e.g., ``a.grad is a.grad``); wrapping on the fly at every access would
    break that. Moreover, the stored tensor might already be the subclass,
    in which case wrapping really ought not to happen.

    Not every property accessor belongs here: ``Tensor.T``, for example, builds
    a new transposed tensor on each access, so those calls SHOULD be interposed
    on (check the implementation of the function to tell which case applies).
    A property accessor that does not return a Tensor need not be listed,
    though listing it is harmless.
    """
    tensor_cls = torch.Tensor
    # Getters of the properties that return a stored tensor as-is.
    stored_tensor_getters = (
        tensor_cls._base.__get__,
        tensor_cls.grad.__get__,
        tensor_cls._grad.__get__,
    )
    return set(stored_tensor_getters)
400
401
402@functools.lru_cache(None)
403@_disable_user_warnings
404def get_testing_overrides() -> Dict[Callable, Callable]:
405    """Return a dict containing dummy overrides for all overridable functions
406
407    Returns
408    -------
409    Dict[Callable, Callable]
410        A dictionary that maps overridable functions in the PyTorch API to
411        lambda functions that have the same signature as the real function
412        and unconditionally return -1. These lambda functions are useful
413        for testing API coverage for a type that defines ``__torch_function__``.
414
415    Examples
416    --------
417    >>> import inspect
418    >>> my_add = torch.overrides.get_testing_overrides()[torch.add]
419    >>> inspect.signature(my_add)
420    <Signature (input, other, out=None)>
421    """
422    # Every function in the PyTorch API that can be overridden needs an entry
423    # in this dict.
424    #
425    # Optimally we would use inspect to get the function signature and define
426    # the lambda function procedurally but that is blocked by generating
427    # function signatures for native kernels that can be consumed by inspect.
428    # See Issue #28233.
429    Tensor = torch.Tensor
430    ret: Dict[Callable, Callable] = {
431        torch.abs: lambda input, out=None: -1,
432        torch.absolute: lambda input, out=None: -1,
433        torch.adaptive_avg_pool1d: lambda input, output_size: -1,
434        torch.adaptive_max_pool1d: lambda inputs, output_size: -1,
435        torch.acos: lambda input, out=None: -1,
436        torch.adjoint: lambda input: -1,
437        torch.arccos: lambda input, out=None: -1,
438        torch.acosh: lambda input, out=None: -1,
439        torch.arccosh: lambda input, out=None: -1,
440        torch.add: lambda input, other, out=None: -1,
441        torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
442        torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,
443        torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1,
444        torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
445        torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1,
446        torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1,
447        torch.affine_grid_generator: lambda theta, size, align_corners: -1,
448        torch.all: lambda input, dim=None: -1,
449        torch.allclose: lambda input, other, trol=1e-05, atol=1e-08, equal_nan=False: -1,
450        torch.alpha_dropout: lambda input, p, train, inplace=False: -1,
451        torch.amax: lambda input, dim=None: -1,
452        torch.amin: lambda input, dim=None: -1,
453        torch.aminmax: lambda input, dim=None, keepdim=False, out=None: -1,
454        torch.angle: lambda input, out=None: -1,
455        torch.any: lambda input, dim=None, keepdim=False, out=None: -1,
456        torch.argmax: lambda input: -1,
457        torch.argmin: lambda input: -1,
458        torch.argsort: lambda input, dim=None: -1,
459        torch.asin: lambda input, out=None: -1,
460        torch._assert_async: lambda input, msg: -1,
461        torch.arcsin: lambda input, out=None: -1,
462        torch.asinh: lambda input, out=None: -1,
463        torch.arcsinh: lambda input, out=None: -1,
464        torch.atan: lambda input, out=None: -1,
465        torch.arctan: lambda input, out=None: -1,
466        torch.atan2: lambda input, other, out=None: -1,
467        torch.arctan2: lambda input, other, out=None: -1,
468        torch.atanh: lambda input, out=None: -1,
469        torch.arctanh: lambda input, out=None: -1,
470        torch.atleast_1d: lambda *tensors: -1,
471        torch.atleast_2d: lambda *tensors: -1,
472        torch.atleast_3d: lambda *tensors: -1,
473        torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1,
474        torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
475        torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1,
476        torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor: -1,
477        torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1,
478        torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1,
479        torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
480        torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
481        torch.batch_norm_stats: lambda input, eps: -1,
482        torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
483        torch.bernoulli: lambda input, generator=None, out=None: -1,
484        torch.bilinear: lambda input1, input2, weight, bias: -1,
485        torch.binary_cross_entropy_with_logits: (
486            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
487        ),
488        torch.bincount: lambda input, weights=None, minlength=0: -1,
489        torch.binomial: lambda count, prob, generator=None: -1,
490        torch.bitwise_and: lambda input, other, out=None: -1,
491        torch.bitwise_not: lambda input, out=None: -1,
492        torch.bitwise_or: lambda input, other, out=None: -1,
493        torch.bitwise_xor: lambda input, other, out=None: -1,
494        torch.bitwise_left_shift: lambda input, other, out=None: -1,
495        torch.bitwise_right_shift: lambda input, other, out=None: -1,
496        torch.block_diag: lambda *tensors: -1,
497        torch.bmm: lambda input, mat2, out=None: -1,
498        torch.broadcast_tensors: lambda *tensors: -1,
499        torch.broadcast_to: lambda self, size: -1,
500        torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1,
501        torch.cartesian_prod: lambda *tensors: -1,
502        torch.cat: lambda tensors, dim=0, out=None: -1,
503        torch.concat: lambda tensors, dim=0, out=None: -1,  # alias for torch.cat
504        torch.concatenate: lambda tensors, dim=0, out=None: -1,  # alias for torch.concatenate
505        torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1,
506        torch.ceil: lambda input, out=None: -1,
507        torch.celu: lambda input, alpha=1.0, inplace=False: -1,
508        torch.chain_matmul: lambda *matrices, out=None: -1,
509        torch.channel_shuffle: lambda input, groups: -1,
510        torch.cholesky: lambda input, upper=False, out=None: -1,
511        torch.linalg.cholesky: lambda input, out=None: -1,
512        torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,
513        torch.cholesky_inverse: lambda input, upper=False, out=None: -1,
514        torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1,
515        torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1,
516        torch.chunk: lambda input, chunks, dim=0: -1,
517        torch.clamp: lambda input, min=None, max=None, out=None: -1,
518        torch.clip: lambda input, min=None, max=None, out=None: -1,
519        torch.clamp_min: lambda input, min, out=None: -1,
520        torch.clamp_max: lambda input, max, out=None: -1,
521        torch.column_stack: lambda tensors, out=None: -1,
522        torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1,
523        torch.clone: lambda input: -1,
524        torch.combinations: lambda input, r=2, with_replacement=False: -1,
525        torch.complex: lambda real, imag: -1,
526        torch.copysign: lambda input, other, out=None: -1,
527        torch.polar: lambda abs, ang: -1,
528        torch.linalg.cond: lambda input, ord=None: -1,
529        torch.conj: lambda input, out=None: -1,
530        torch.conj_physical: lambda input, out=None: -1,
531        torch.resolve_conj: lambda input, out=None: -1,
532        torch.resolve_neg: lambda input, out=None: -1,
533        torch.constant_pad_nd: lambda input, pad, value=0: -1,
534        torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
535        torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
536        torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
537        torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1,
538        torch.conv_tbc: lambda input, weight, bias, pad=0: -1,
539        torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
540        torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
541        torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
542        torch.corrcoef: lambda input: -1,
543        torch.cos: lambda input, out=None: -1,
544        torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,
545        torch.cosh: lambda input, out=None: -1,
546        torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
547        torch.count_nonzero: lambda input: -1,
548        torch.cross: lambda input, other, dim=None, out=None: -1,
549        torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
550        torch.ctc_loss: (
551            lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
552        ),
553        torch.cummax: lambda input, dim, out=None: -1,
554        torch.cummin: lambda input, dim, out=None: -1,
555        torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
556        torch.cumsum: lambda input, dim, out=None, dtype=None: -1,
557        torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1,
558        torch.logcumsumexp: lambda input, dim, out=None: -1,
559        torch.deg2rad: lambda input, out=None: -1,
560        torch.dequantize: lambda input: -1,
561        torch.det: lambda input: -1,
562        torch.linalg.det: lambda input: -1,  # alias for torch.det  # type: ignore[attr-defined]
563        torch.detach: lambda input: -1,
564        torch.diag: lambda input, diagonal=0, out=None: -1,
565        torch.diag_embed: lambda input, diagonal=0, out=None: -1,
566        torch.diagflat: lambda input, offset=0: -1,
567        torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
568        torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
569        torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
570        torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
571        torch.as_strided_scatter: lambda self, src, size, stride, storage_offset=None: -1,
572        torch.digamma: lambda input, out=None: -1,
573        torch.dist: lambda input, other, p=2: -1,
574        torch.div: lambda input, other, rounding_mode=None, out=None: -1,
575        torch.divide: lambda input, other, rounding_mode=None, out=None: -1,
576        torch.dot: lambda input, other, out=None: -1,
577        torch.dropout: lambda input, p, train, inplace=False: -1,
578        torch.dsmm: lambda input, mat2: -1,
579        torch.hsmm: lambda mat1, mat2: -1,
580        torch.dsplit: lambda input, indices_or_sections: -1,
581        torch.dstack: lambda tensors, out=None: -1,
582        torch.linalg.eig: lambda input, out=None: -1,
583        torch.linalg.eigvals: lambda input, out=None: -1,
584        torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
585        torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
586        torch.einsum: lambda equation, *operands: -1,
587        torch.embedding: (
588            lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1  # noqa: B950
589        ),
590        torch.embedding_bag: (
591            lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1  # noqa: B950
592        ),
593        torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
594        torch.eq: lambda input, other, out=None: -1,
595        torch.equal: lambda input, other: -1,
596        torch.erf: lambda input, out=None: -1,
597        torch.erfc: lambda input, out=None: -1,
598        torch.erfinv: lambda input, out=None: -1,
599        torch.exp: lambda input, out=None: -1,
600        torch.exp2: lambda input, out=None: -1,
601        torch.expm1: lambda input, out=None: -1,
602        torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
603        torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
604        torch.fused_moving_avg_obs_fake_quant: (
605            lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1  # noqa: B950
606        ),
607        torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
608        torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
609        torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,  # noqa: B950
610        torch.fbgemm_linear_int8_weight_fp32_activation: (
611            lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1
612        ),
613        torch.fbgemm_linear_quantize_weight: lambda input: -1,
614        torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
615        torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
616        torch.feature_alpha_dropout: lambda input, p, train: -1,
617        torch.feature_dropout: lambda input, p, train: -1,
618        torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
619        torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1,
620        torch.fft.irfft: lambda input, n=None, dim=-1, norm=None: -1,
621        torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1,
622        torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1,
623        torch.fft.hfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
624        torch.fft.ihfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
625        torch.fft.hfftn: lambda input, s=None, dim=-1, norm=None: -1,
626        torch.fft.ihfftn: lambda input, s=None, dim=-1, norm=None: -1,
627        torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1,
628        torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1,
629        torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1,
630        torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1,
631        torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
632        torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
633        torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
634        torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
635        torch.fft.fftshift: lambda input, dim=None: -1,
636        torch.fft.ifftshift: lambda input, dim=None: -1,
637        torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1,
638        torch.fix: lambda input, out=None: -1,
639        torch.flatten: lambda input, start_dim=0, end_dim=-1: -1,
640        torch.flip: lambda input, dims: -1,
641        torch.fliplr: lambda input: -1,
642        torch.flipud: lambda input: -1,
643        torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
644        torch.floor: lambda input, out=None: -1,
645        torch.floor_divide: lambda input, other: -1,
646        torch.float_power: lambda input, exponent, out=None: -1,
647        torch.fmod: lambda input, other, out=None: -1,
648        torch.frac: lambda input, out=None: -1,
649        torch.frexp: lambda input, out=None: -1,
650        torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,  # noqa: B950
651        torch._functional_assert_async: lambda input, msg, dep_token: -1,
652        torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
653        torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
654        torch.gcd: lambda input, other, out=None: -1,
655        torch.ge: lambda input, other, out=None: -1,
656        torch.greater_equal: lambda input, other, out=None: -1,
657        torch.geqrf: lambda input, out=None: -1,
658        torch.i0: lambda input, out=None: -1,
659        torch.inner: lambda input, other, out=None: -1,
660        torch.outer: lambda input, vec2, out=None: -1,
661        torch.ger: lambda input, vec2, out=None: -1,  # alias for torch.outer
662        torch.gradient: lambda input, spacing=None, dim=None, edge_order=1: -1,
663        torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
664        torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
665        torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
666        torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,
667        torch.gru: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
668        torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
669        torch.gt: lambda input, other, out=None: -1,
670        torch.greater: lambda input, other, out=None: -1,
671        torch.hardshrink: lambda input, lambd=0.5: -1,
672        torch.heaviside: lambda input, values, out=None: -1,
673        torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1,  # noqa: B950
674        torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
675        torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
676        torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
677        torch.linalg.householder_product: lambda input, tau: -1,
678        torch.hspmm: lambda mat1, mat2, out=None: -1,
679        torch.hsplit: lambda input, indices_or_sections: -1,
680        torch.hstack: lambda tensors, out=None: -1,
681        torch.hypot: lambda input, other, out=None: -1,
682        torch.igamma: lambda input, other, out=None: -1,
683        torch.igammac: lambda input, other, out=None: -1,
684        torch.imag: lambda input, out=None: -1,
685        torch.index_add: lambda input, dim, index, source: -1,
686        torch.index_copy: lambda input, dim, index, source: -1,
687        torch.index_put: lambda input, indices, values, accumulate=False: -1,
688        torch.index_select: lambda input, dim, index, out=None: -1,
689        torch.index_fill: lambda input, dim, index, value: -1,
690        torch.index_reduce: lambda input, dim, index, source, reduce, include_input=True: -1,
691        torch.isfinite: lambda tensor: -1,
692        torch.isin: lambda e, te, assume_unique=False, invert=False: -1,
693        torch.isinf: lambda tensor: -1,
694        torch.isreal: lambda tensor: -1,
695        torch.isposinf: lambda input, out=None: -1,
696        torch.isneginf: lambda input, out=None: -1,
697        torch.instance_norm: (
698            lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1
699        ),
700        torch.int_repr: lambda input: -1,
701        torch.inverse: lambda input, out=None: -1,
702        torch.linalg.inv: lambda input, out=None: -1,
703        torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
704        torch.is_complex: lambda input: -1,
705        torch.is_conj: lambda input: -1,
706        torch.is_neg: lambda input: -1,
707        torch.is_distributed: lambda input: -1,
708        torch.is_inference: lambda input: -1,
709        torch.is_floating_point: lambda input: -1,
710        torch.is_nonzero: lambda input: -1,
711        torch.is_same_size: lambda input, other: -1,
712        torch.is_signed: lambda input: -1,
713        torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
714        torch.isnan: lambda input: -1,
715        torch.istft: (
716            lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1  # noqa: B950
717        ),
718        torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,
719        torch.kron: lambda input, other: -1,
720        torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
721        torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
722        torch.linalg.ldl_factor: lambda input, hermitian=False, out=None: -1,
723        torch.linalg.ldl_solve: lambda LD, pivots, B, hermitian=False, out=None: -1,
724        torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1,
725        torch.lcm: lambda input, other, out=None: -1,
726        torch.ldexp: lambda input, other, out=None: -1,
727        torch.le: lambda input, other, out=None: -1,
728        torch.less_equal: lambda input, other, out=None: -1,
729        torch.lerp: lambda input, end, weight, out=None: -1,
730        torch.lgamma: lambda input, out=None: -1,
731        torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,  # noqa: B950
732        torch.log: lambda input, out=None: -1,
733        torch.log_softmax: lambda input, dim, dtype=None: -1,
734        torch.log10: lambda input, out=None: -1,
735        torch.log1p: lambda input, out=None: -1,
736        torch.log2: lambda input, out=None: -1,
737        torch.logaddexp: lambda input, other, out=None: -1,
738        torch.logaddexp2: lambda input, other, out=None: -1,
739        torch.logdet: lambda input: -1,
740        torch.xlogy: lambda x, y, out=None: -1,
741        torch.logical_and: lambda input, other, out=None: -1,
742        torch.logical_not: lambda input, out=None: -1,
743        torch.logical_or: lambda input, other, out=None: -1,
744        torch.logical_xor: lambda input, other, out=None: -1,
745        torch.logit: lambda input, eps=None: -1,
746        torch.logsumexp: lambda input, names, keepdim=False, out=None: -1,
747        torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1,
748        torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
749        torch.lt: lambda input, other, out=None: -1,
750        torch.less: lambda input, other, out=None: -1,
751        torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
752        torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
753        torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,  # type: ignore[attr-defined]  # noqa: B950
754        torch.masked_fill: lambda input, mask, value: -1,
755        torch.masked_scatter: lambda input, mask, source: -1,
756        torch.masked_select: lambda input, mask, out=None: -1,
757        torch.matmul: lambda input, other, out=None: -1,
758        torch.linalg.lu: lambda input, pivot=True, out=None: -1,
759        torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1,
760        torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1,
761        torch.linalg.lu_solve: lambda LU, pivots, B, left=True, adjoint=False, out=None: -1,
762        torch.linalg.matmul: lambda input, other, out=None: -1,  # alias for torch.matmul
763        torch.matrix_power: lambda input, n: -1,
764        torch.linalg.matrix_power: lambda input, n, out=None: -1,
765        torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
766        torch.linalg.multi_dot: lambda tensors, out=None: -1,
767        torch.matrix_exp: lambda input: -1,
768        torch.linalg.matrix_exp: lambda input: -1,
769        torch.max: lambda input, out=None: -1,
770        torch.maximum: lambda input, other, out=None: -1,
771        torch.fmax: lambda input, other, out=None: -1,
772        torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
773        torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
774        torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
775        torch.max_pool1d_with_indices: (
776            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
777        ),
778        torch.mean: lambda input, dim=None: -1,
779        torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
780        torch.median: lambda input, dim=None: -1,
781        torch.nanmedian: lambda input, dim=None: -1,
782        torch.meshgrid: lambda *tensors, **kwargs: -1,
783        torch.min: lambda input, out=None: -1,
784        torch.minimum: lambda input, other, out=None: -1,
785        torch.fmin: lambda input, other, out=None: -1,
786        torch.miopen_batch_norm: (
787            lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1
788        ),
789        torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,  # noqa: B950
790        torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
791        torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
792        torch.miopen_convolution_transpose: (
793            lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1
794        ),
795        torch.miopen_depthwise_convolution: (
796            lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1
797        ),
798        torch.miopen_rnn: (
799            lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1  # noqa: B950
800        ),
801        torch.mm: lambda input, mat2, out=None: -1,
802        torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
803        torch.movedim: lambda input, source, destination: -1,
804        torch.moveaxis: lambda input, source, destination: -1,
805        torch.msort: lambda input, descending=False, out=None: -1,
806        torch.mul: lambda input, other, out=None: -1,
807        torch.multiply: lambda input, other, out=None: -1,
808        torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
809        torch.mv: lambda input, vec, out=None: -1,
810        torch.mvlgamma: lambda input, p: -1,
811        torch.narrow: lambda input, dim, start, length: -1,
812        torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
813        torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
814        torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
815        torch.native_dropout: lambda input, p, train: -1,
816        torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
817        torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
818        torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
819        torch.native_channel_shuffle: lambda input, groups: -1,
820        torch.ne: lambda input, other, out=None: -1,
821        torch.not_equal: lambda input, other, out=None: -1,
822        torch.neg: lambda input, out=None: -1,
823        torch.negative: lambda input, out=None: -1,
824        torch.nextafter: lambda input, other, out=None: -1,
825        torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1,
826        torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1,
827        torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1,
828        torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1,
829        torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1,
830        torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1,
831        torch.nn.functional.adaptive_max_pool3d: lambda input, output_size, return_indices=False: -1,
832        torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
833        torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
834        torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
835        torch.nn.functional.avg_pool2d: (
836            lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1  # noqa: B950
837        ),
838        torch.nn.functional.avg_pool3d: (
839            lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1  # noqa: B950
840        ),
841        torch.nn.functional.batch_norm: (
842            lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1
843        ),
844        torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
845        torch.nn.functional.binary_cross_entropy: (
846            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
847        ),
848        torch.nn.functional.binary_cross_entropy_with_logits: (
849            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
850        ),
851        torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
852        torch.nn.functional.cosine_embedding_loss: (
853            lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
854        ),
855        torch.nn.functional.cross_entropy: (
856            lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1  # noqa: B950
857        ),
858        torch.nn.functional.ctc_loss: (
859            lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
860        ),
861        torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
862        torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
863        torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
864        torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
865        torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
866        torch.nn.functional.embedding: (
867            lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1  # noqa: B950
868        ),
869        torch.nn.functional.embedding_bag: (
870            lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1  # noqa: B950
871        ),
872        torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
873        torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
874        torch.nn.functional.fractional_max_pool2d: (
875            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
876        ),
877        torch.nn.functional.fractional_max_pool2d_with_indices: (
878            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
879        ),
880        torch.nn.functional.fractional_max_pool3d: (
881            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
882        ),
883        torch.nn.functional.fractional_max_pool3d_with_indices: (
884            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
885        ),
886        torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1,
887        torch.nn.functional.gelu: lambda input, approximate="none": -1,
888        torch.nn.functional.glu: lambda input, dim=-1: -1,
889        torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1,  # noqa: B950
890        torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
891        torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
892        torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
893        torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1,
894        torch.nn.functional.hinge_embedding_loss: (
895            lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1
896        ),
897        torch.nn.functional.instance_norm: (
898            lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1  # noqa: B950
899        ),
900        torch.nn.functional.interpolate: (
901            lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1  # noqa: B950
902        ),
903        torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,  # noqa: B950
904        torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
905        torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
906        torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
907        torch.nn.functional.linear: lambda input, weight, bias=None: -1,
908        torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1,
909        torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
910        torch.nn.functional.logsigmoid: lambda input: -1,
911        torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
912        torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
913        torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
914        torch.nn.functional.margin_ranking_loss: (
915            lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
916        ),
917        torch.nn.functional.max_pool1d: (
918            lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
919        ),
920        torch.nn.functional.max_pool1d_with_indices: (
921            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
922        ),
923        torch.nn.functional.max_pool2d: (
924            lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
925        ),
926        torch.nn.functional.max_pool2d_with_indices: (
927            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
928        ),
929        torch.nn.functional.max_pool3d: (
930            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
931        ),
932        torch.nn.functional.max_pool3d_with_indices: (
933            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
934        ),
935        torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
936        torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
937        torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
938        torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
939        torch.nn.functional.multi_head_attention_forward: (
940            lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1  # noqa: B950
941        ),
942        torch.nn.functional.multi_margin_loss: (
943            lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1
944        ),
945        torch.nn.functional.multilabel_margin_loss: (
946            lambda input, target, size_average=None, reduce=None, reduction="mean": -1
947        ),
948        torch.nn.functional.multilabel_soft_margin_loss: (
949            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
950        ),
951        torch.nn.functional.nll_loss: (
952            lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1
953        ),
954        torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
955        torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
956        torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1,
957        torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
958        torch.nn.functional.poisson_nll_loss: (
959            lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1  # noqa: B950
960        ),
961        torch.nn.functional.prelu: lambda input, weight: -1,
962        torch.nn.functional.relu: lambda input, inplace=False: -1,
963        torch.nn.functional.relu6: lambda input, inplace=False: -1,
964        torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
965        torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,  # noqa: B950
966        torch.nn.functional.selu: lambda input, inplace=False: -1,
967        torch.nn.functional.silu: lambda input, inplace=False: -1,
968        torch.nn.functional.mish: lambda input, inplace=False: -1,
969        torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
970        torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1,  # noqa: B950
971        torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0: -1,
972        torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,  # noqa: B950
973        torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
974        torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
975        torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
976        torch.nn.functional.softshrink: lambda input, lambd=0.5: -1,
977        torch.nn.functional.softsign: lambda input: -1,
978        torch.nn.functional.tanhshrink: lambda input: -1,
979        torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
980        torch.nn.functional.triplet_margin_loss: (
981            lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1  # noqa: B950
982        ),
983        torch.nn.functional.triplet_margin_with_distance_loss: (
984            lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1
985        ),
986        torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
987        torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1,
988        torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1,
989        torch.nn.init.constant_: lambda tensor, val: -1,
990        torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1,  # noqa: B950
991        torch.nonzero: lambda input, as_tuple=False: -1,
992        torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
993        torch.argwhere: lambda input: -1,
994        torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
995        torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
996        torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
997        torch.linalg.matrix_norm: lambda input, ord="fro", dim=(
998            -2,
999            -1,
1000        ), keepdim=False, out=None, dtype=None: -1,
1001        torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
1002        torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
1003        torch.numel: lambda input: -1,
1004        torch.orgqr: lambda input, tau: -1,
1005        torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
1006        torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
1007        torch.permute: lambda self, dim: -1,
1008        torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1,
1009        torch.pdist: lambda input, p=2: -1,
1010        torch.pinverse: lambda input, rcond=1e-15: -1,
1011        torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1,
1012        torch.pixel_shuffle: lambda input, upscale_factor: -1,
1013        torch.pixel_unshuffle: lambda input, downscale_factor: -1,
1014        torch.poisson: lambda input, generator=None: -1,
1015        torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1,
1016        torch.polygamma: lambda input, n, out=None: -1,
1017        torch.positive: lambda input, out=None: -1,
1018        torch.prelu: lambda input, weight: -1,
1019        torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
1020        torch.pow: lambda input, exponent, out=None: -1,
1021        torch.prod: lambda input, dtype=None: -1,
1022        torch.put: lambda input, index, source, accumulate=False: -1,
1023        torch.q_per_channel_axis: lambda input: -1,
1024        torch.q_per_channel_scales: lambda input: -1,
1025        torch.q_per_channel_zero_points: lambda input: -1,
1026        torch.q_scale: lambda input: -1,
1027        torch.q_zero_point: lambda input: -1,
1028        torch.qr: lambda input, some=True, out=None: -1,
1029        torch.linalg.qr: lambda input, mode="reduced", out=None: -1,
1030        torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
1031        torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
1032        torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
1033        torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
1034        torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
1035        torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
1036        torch.quantized_gru_cell: (
1037            lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
1038        ),
1039        torch.quantized_lstm_cell: (
1040            lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
1041        ),
1042        torch.quantized_max_pool1d: (
1043            lambda input, kernel_size, stride=(), padding=(0,), dilation=(
1044                1,
1045            ), ceil_mode=False: -1
1046        ),
1047        torch.quantized_max_pool2d: (
1048            lambda input, kernel_size, stride=(), padding=(0, 0), dilation=(
1049                1,
1050                1,
1051            ), ceil_mode=False: -1
1052        ),
1053        torch.quantized_max_pool3d: (
1054            lambda input, kernel_size, stride=(), padding=(0, 0, 0), dilation=(
1055                1,
1056                1,
1057                1,
1058            ), ceil_mode=False: -1
1059        ),
1060        torch.quantized_rnn_relu_cell: (
1061            lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
1062        ),
1063        torch.quantized_rnn_tanh_cell: (
1064            lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
1065        ),
1066        torch.rad2deg: lambda input, out=None: -1,
1067        torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
1068        torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
1069        torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
1070        torch.ravel: lambda input: -1,
1071        torch.real: lambda input, out=None: -1,
1072        torch.vdot: lambda input, other, out=None: -1,
1073        torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
1074        torch.view_as_real: lambda input: -1,
1075        torch.view_as_complex: lambda input: -1,
1076        torch.reciprocal: lambda input, out=None: -1,
1077        torch.relu: lambda input, inplace=False: -1,
1078        torch.remainder: lambda input, other, out=None: -1,
1079        torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
1080        torch.repeat_interleave: lambda input, dim=None: -1,
1081        torch.reshape: lambda input, shape: -1,
1082        torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
1083        torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,  # noqa: B950
1084        torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
1085        torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,  # noqa: B950
1086        torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
1087        torch.roll: lambda input, shifts, dims=None: -1,
1088        torch.rot90: lambda input, k=1, dims=(0, 1): -1,
1089        torch.round: lambda input, out=None: -1,
1090        torch.row_stack: lambda tensors, out=None: -1,  # alias for torch.vstack
1091        torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
1092        torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1,
1093        torch.rsqrt: lambda input, out=None: -1,
1094        torch.rsub: lambda input, other, alpha=1: -1,
1095        torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
1096        torch.scatter: lambda input, dim, index, src: -1,
1097        torch.scatter_add: lambda input, dim, index, src: -1,
1098        torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
1099        torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
1100        torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,  # noqa: B950
1101        torch.select: lambda input, dim, index: -1,
1102        torch.select_scatter: lambda input, src, dim, index: -1,
1103        torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
1104        torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
1105        torch.selu: lambda input, inplace=False: -1,
1106        torch.sigmoid: lambda input, out=None: -1,
1107        torch.sign: lambda input, out=None: -1,
1108        torch.signbit: lambda input, out=None: -1,
1109        torch.sgn: lambda input, out=None: -1,
1110        torch.sin: lambda input, out=None: -1,
1111        torch.sinc: lambda input, out=None: -1,
1112        torch.sinh: lambda input, out=None: -1,
1113        torch.slogdet: lambda input: -1,
1114        torch.linalg.slogdet: lambda input: -1,
1115        torch.smm: lambda input, mat2: -1,
1116        torch.spmm: lambda input, mat2: -1,
1117        torch.softmax: lambda input, dim, dtype=None: -1,
1118        torch.linalg.solve: lambda A, B, left=True, out=None: -1,
1119        torch.linalg.solve_ex: lambda A, B, left=True, check_errors=False, out=None: -1,
1120        torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1,
1121        torch.split: lambda tensor, split_size_or_sections, dim=0: -1,
1122        torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
1123        torch.sqrt: lambda input, out=None: -1,
1124        torch.square: lambda input, out=None: -1,
1125        torch.squeeze: lambda input, dim=None, out=None: -1,
1126        torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
1127        torch.stack: lambda tensors, dim=0, out=None: -1,
1128        torch.std: lambda input, dim=None: -1,
1129        torch.std_mean: lambda input, dim=None: -1,
1130        torch.stft: (
1131            lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1  # noqa: B950
1132        ),
1133        torch.sub: lambda input, other, out=None: -1,
1134        torch.subtract: lambda input, other, out=None: -1,
1135        torch.sum: lambda input, dim=None: -1,
1136        torch.sym_float: lambda input: -1,
1137        torch.sym_int: lambda input: -1,
1138        torch.sym_max: lambda a, b: -1,
1139        torch.sym_min: lambda a, b: -1,
1140        torch.sym_not: lambda input: -1,
1141        torch.sym_ite: lambda a, b, c: -1,
1142        torch._sym_sqrt: lambda input: -1,
1143        torch._sym_cos: lambda input: -1,
1144        torch._sym_cosh: lambda input: -1,
1145        torch._sym_sin: lambda input: -1,
1146        torch._sym_sinh: lambda input: -1,
1147        torch._sym_tan: lambda input: -1,
1148        torch._sym_tanh: lambda input: -1,
1149        torch._sym_asin: lambda input: -1,
1150        torch._sym_acos: lambda input: -1,
1151        torch._sym_atan: lambda input: -1,
1152        torch.nansum: lambda input, dim=None: -1,
1153        torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
1154        torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
1155        torch.linalg.svd: lambda input, full_matrices=True, out=None: -1,
1156        torch.linalg.svdvals: lambda input, out=None: -1,
1157        torch.swapaxes: lambda input, dim0, dim1: -1,
1158        torch.swapdims: lambda input, axis0, axis1: -1,
1159        torch.special.airy_ai: lambda input: -1,
1160        torch.special.bessel_j0: lambda input: -1,
1161        torch.special.bessel_j1: lambda input: -1,
1162        torch.special.bessel_y0: lambda input: -1,
1163        torch.special.bessel_y1: lambda input: -1,
1164        torch.special.chebyshev_polynomial_t: lambda input, n, out=None: -1,
1165        torch.special.chebyshev_polynomial_u: lambda input, n, out=None: -1,
1166        torch.special.chebyshev_polynomial_v: lambda input, n, out=None: -1,
1167        torch.special.chebyshev_polynomial_w: lambda input, n, out=None: -1,
1168        torch.special.digamma: lambda input: -1,
1169        torch.special.entr: lambda input: -1,
1170        torch.special.erf: lambda input: -1,
1171        torch.special.erfc: lambda input: -1,
1172        torch.special.erfcx: lambda input: -1,
1173        torch.special.erfinv: lambda input: -1,
1174        torch.special.exp2: lambda input: -1,
1175        torch.special.expit: lambda input: -1,
1176        torch.special.expm1: lambda input: -1,
1177        torch.special.gammainc: lambda input, other, out=None: -1,
1178        torch.special.gammaincc: lambda input, other, out=None: -1,
1179        torch.special.gammaln: lambda input: -1,
1180        torch.special.hermite_polynomial_h: lambda input, n, out=None: -1,
1181        torch.special.hermite_polynomial_he: lambda input, n, out=None: -1,
1182        torch.special.i0: lambda input: -1,
1183        torch.special.i0e: lambda input: -1,
1184        torch.special.i1: lambda input: -1,
1185        torch.special.i1e: lambda input: -1,
1186        torch.special.laguerre_polynomial_l: lambda input, n, out=None: -1,
1187        torch.special.legendre_polynomial_p: lambda input, n, out=None: -1,
1188        torch.special.log1p: lambda input: -1,
1189        torch.special.log_ndtr: lambda input: -1,
1190        torch.special.log_softmax: lambda input, dim, dtype=None: -1,
1191        torch.special.logit: lambda input: -1,
1192        torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1,
1193        torch.special.modified_bessel_i0: lambda input: -1,
1194        torch.special.modified_bessel_i1: lambda input: -1,
1195        torch.special.modified_bessel_k0: lambda input: -1,
1196        torch.special.modified_bessel_k1: lambda input: -1,
1197        torch.special.multigammaln: lambda input, p: -1,
1198        torch.special.ndtr: lambda input: -1,
1199        torch.special.ndtri: lambda input: -1,
1200        torch.special.polygamma: lambda input, n, out=None: -1,
1201        torch.special.psi: lambda input: -1,
1202        torch.special.round: lambda input: -1,
1203        torch.special.scaled_modified_bessel_k0: lambda input: -1,
1204        torch.special.scaled_modified_bessel_k1: lambda input: -1,
1205        torch.special.shifted_chebyshev_polynomial_t: lambda input, n, out=None: -1,
1206        torch.special.shifted_chebyshev_polynomial_u: lambda input, n, out=None: -1,
1207        torch.special.shifted_chebyshev_polynomial_v: lambda input, n, out=None: -1,
1208        torch.special.shifted_chebyshev_polynomial_w: lambda input, n, out=None: -1,
1209        torch.special.sinc: lambda input: -1,
1210        torch.special.softmax: lambda input, dim, dtype=None: -1,
1211        torch.special.spherical_bessel_j0: lambda input: -1,
1212        torch.special.xlog1py: lambda input, other, out=None: -1,
1213        torch.special.xlogy: lambda input, other, out=None: -1,
1214        torch.special.zeta: lambda self, other, out=None: -1,
1215        torch.t: lambda input: -1,
1216        torch.take: lambda input, index: -1,
1217        torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,
1218        torch.tan: lambda input, out=None: -1,
1219        torch.tanh: lambda input, out=None: -1,
1220        torch.linalg.tensorinv: lambda a, ind=2: -1,
1221        torch.linalg.tensorsolve: lambda a, b, dims=None: -1,
1222        torch.tensordot: lambda a, b, dims=2, out=None: -1,
1223        torch.tensor_split: lambda input, indices_or_sections, dim=0: -1,
1224        torch.threshold: lambda input, threshold, value, inplace=False: -1,
1225        torch.tile: lambda input, dims: -1,
1226        torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1,
1227        torch.trace: lambda input: -1,
1228        torch.transpose: lambda input, dim0, dim1: -1,
1229        torch.trapz: lambda y, x=None, dim=-1: -1,
1230        torch.trapezoid: lambda y, x=None, dim=-1: -1,
1231        torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
1232        torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
1233        torch.tril: lambda input, diagonal=0, out=None: -1,
1234        torch.triplet_margin_loss: (
1235            lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1  # noqa: B950
1236        ),
1237        torch.triu: lambda input, diagonal=0, out=None: -1,
1238        torch.true_divide: lambda input, other: -1,
1239        torch.trunc: lambda input, out=None: -1,
1240        torch.unbind: lambda input, dim=0: -1,
1241        torch.unflatten: lambda input, dim, sizes, names: -1,
1242        torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
1243        torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
1244        torch.unravel_index: lambda indices, shape: -1,
1245        torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
1246        torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1,
1247        torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
1248        torch.unsqueeze: lambda input, dim, out=None: -1,
1249        torch.linalg.vander: lambda x, N=None: -1,
1250        torch.var: lambda input, dim=None: -1,
1251        torch.var_mean: lambda input, dim=None: -1,
1252        torch.vsplit: lambda input, indices_or_sections: -1,
1253        torch.vstack: lambda tensors, out=None: -1,
1254        torch.where: lambda condition, x=None, y=None: -1,
1255        torch._wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1,
1256        torch._wrapped_quantized_linear_prepacked: (
1257            lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel : -1  # noqa: B950
1258        ),
1259        torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
1260        torch._fw_primal_copy: lambda self, level: -1,
1261        torch._make_dual_copy: lambda primal, tangent, level: -1,
1262        torch.view_as_real_copy: lambda self: -1,
1263        torch.view_as_complex_copy: lambda self: -1,
1264        torch._conj_copy: lambda self: -1,
1265        torch._neg_view_copy: lambda self: -1,
1266        torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1,
1267        torch._sparse_broadcast_to_copy: lambda self, size: -1,
1268        torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1,
1269        torch.expand_copy: lambda self, size, *, implicit=False: -1,
1270        torch.narrow_copy: lambda self, dim, start, length: -1,
1271        torch.permute_copy: lambda self, dims: -1,
1272        torch._reshape_alias_copy: lambda self, size, stride: -1,
1273        torch.select_copy: lambda self, dim, index: -1,
1274        torch.detach_copy: lambda self: -1,
1275        torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1,
1276        torch.split_copy: lambda self, split_size, dim=0: -1,
1277        torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1,
1278        torch.squeeze_copy: lambda self, dim: -1,
1279        torch.t_copy: lambda self: -1,
1280        torch.transpose_copy: lambda self, dim0, dim1: -1,
1281        torch.unsqueeze_copy: lambda self, dim: -1,
1282        torch._indices_copy: lambda self: -1,
1283        torch._values_copy: lambda self: -1,
1284        torch.indices_copy: lambda self: -1,
1285        torch.values_copy: lambda self: -1,
1286        torch.crow_indices_copy: lambda self: -1,
1287        torch.col_indices_copy: lambda self: -1,
1288        torch.ccol_indices_copy: lambda self: -1,
1289        torch.row_indices_copy: lambda self: -1,
1290        torch.unbind_copy: lambda self, dim=0: -1,
1291        torch.view_copy: lambda self, dtype: -1,
1292        torch.unfold_copy: lambda self, dimension, size, step: -1,
1293        torch.alias_copy: lambda self: -1,
1294        Tensor.__floordiv__: lambda self, other: -1,
1295        Tensor.__rfloordiv__: lambda self, other: -1,
1296        Tensor.__ifloordiv__: lambda self, other: -1,
1297        Tensor.__truediv__: lambda self, other: -1,
1298        Tensor.__rtruediv__: lambda self, other: -1,
1299        Tensor.__itruediv__: lambda self, other: -1,
1300        Tensor.__lshift__: lambda self, other: -1,
1301        Tensor.__rlshift__: lambda self, other: -1,
1302        Tensor.__ilshift__: lambda self, other: -1,
1303        Tensor.__rshift__: lambda self, other: -1,
1304        Tensor.__rrshift__: lambda self, other: -1,
1305        Tensor.__irshift__: lambda self, other: -1,
1306        Tensor.__and__: lambda self, other: -1,
1307        Tensor.__or__: lambda self, other: -1,
1308        Tensor.__xor__: lambda self, other: -1,
1309        Tensor.__float__: lambda self: -1,
1310        Tensor.__complex__: lambda self: -1,
1311        Tensor.__array__: lambda self, dtype: -1,
1312        Tensor.__bool__: lambda self: -1,
1313        Tensor.__contains__: lambda self, other: -1,
1314        Tensor.__neg__: lambda self: -1,
1315        Tensor.__invert__: lambda self: -1,
1316        Tensor.__mod__: lambda self, other: -1,
1317        Tensor.__rmod__: lambda self, other: -1,
1318        Tensor.__imod__: lambda self, other: -1,
1319        Tensor.__array_wrap__: lambda self, array: -1,
1320        Tensor.__getitem__: lambda self, idx: -1,
1321        Tensor.__deepcopy__: lambda self, memo: -1,
1322        Tensor.__int__: lambda self: -1,
1323        Tensor.__long__: lambda self: -1,
1324        Tensor.__index__: lambda self: -1,
1325        Tensor.__len__: lambda self: -1,
1326        Tensor.__format__: lambda self, format_spec: -1,
1327        Tensor.__reduce_ex__: lambda self, proto: -1,
1328        Tensor.__reversed__: lambda self: -1,
1329        Tensor.__repr__: lambda self, *, tensor_contents=None: -1,
1330        Tensor.__setitem__: lambda self, k, v: -1,
1331        Tensor.__setstate__: lambda self, d: -1,
1332        Tensor.T.__get__: lambda self: -1,
1333        Tensor.H.__get__: lambda self: -1,
1334        Tensor.mT.__get__: lambda self: -1,
1335        Tensor.mH.__get__: lambda self: -1,
1336        Tensor._backward_hooks.__get__: lambda self: -1,
1337        Tensor._post_accumulate_grad_hooks.__get__: lambda self: -1,
1338        Tensor._base.__get__: lambda self: -1,
1339        Tensor._cdata.__get__: lambda self: -1,
1340        Tensor.grad.__get__: lambda self: -1,
1341        Tensor._grad.__get__: lambda self: -1,
1342        Tensor._grad_fn.__get__: lambda self: -1,
1343        Tensor.grad_fn.__get__: lambda self: -1,
1344        Tensor._version.__get__: lambda self: -1,
1345        Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1,
1346        Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1,
1347        Tensor.data.__get__: lambda self: -1,
1348        Tensor.device.__get__: lambda self: -1,
1349        Tensor.dtype.__get__: lambda self: -1,
1350        Tensor.is_cuda.__get__: lambda self: -1,
1351        Tensor.is_cpu.__get__: lambda self: -1,
1352        Tensor.is_xla.__get__: lambda self: -1,
1353        Tensor.is_xpu.__get__: lambda self: -1,
1354        Tensor.is_ipu.__get__: lambda self: -1,
1355        Tensor.is_leaf.__get__: lambda self: -1,
1356        Tensor.retains_grad.__get__: lambda self: -1,
1357        Tensor.is_meta.__get__: lambda self: -1,
1358        Tensor.is_mps.__get__: lambda self: -1,
1359        Tensor.is_mtia.__get__: lambda self: -1,
1360        Tensor.is_nested.__get__: lambda self: -1,
1361        Tensor.is_maia.__get__: lambda self: -1,
1362        Tensor.is_mkldnn.__get__: lambda self: -1,
1363        Tensor.is_quantized.__get__: lambda self: -1,
1364        Tensor.is_sparse.__get__: lambda self: -1,
1365        Tensor.is_sparse_csr.__get__: lambda self: -1,
1366        Tensor.is_vulkan.__get__: lambda self: -1,
1367        Tensor.itemsize.__get__: lambda self: -1,
1368        Tensor.layout.__get__: lambda self: -1,
1369        Tensor.name.__get__: lambda self: -1,
1370        Tensor.names.__get__: lambda self: -1,
1371        Tensor.nbytes.__get__: lambda self: -1,
1372        Tensor.ndim.__get__: lambda self: -1,
1373        Tensor.output_nr.__get__: lambda self: -1,
1374        Tensor.requires_grad.__get__: lambda self: -1,
1375        Tensor.shape.__get__: lambda self: -1,
1376        Tensor.volatile.__get__: lambda self: -1,
1377        Tensor.real.__get__: lambda self: -1,
1378        Tensor.imag.__get__: lambda self: -1,
1379        Tensor.__cuda_array_interface__.__get__: lambda self: -1,
1380        Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1,
1381        Tensor._dimI: lambda self: -1,
1382        Tensor._dimV: lambda self: -1,
1383        Tensor._indices: lambda self: -1,
1384        Tensor._is_view: lambda self: -1,
1385        Tensor._nnz: lambda self: -1,
1386        Tensor.crow_indices: lambda self: -1,
1387        Tensor.col_indices: lambda self: -1,
1388        Tensor.ccol_indices: lambda self: -1,
1389        Tensor.row_indices: lambda self: -1,
1390        Tensor._update_names: lambda self, names, inplace: -1,
1391        Tensor._values: lambda self: -1,
1392        Tensor.adjoint: lambda self: -1,
1393        Tensor.align_as: lambda self, other: -1,
1394        Tensor.align_to: lambda self, order, ellipsis_idx: -1,
1395        Tensor.apply_: lambda self, callable: -1,
1396        Tensor.as_strided: lambda self, size, stride: -1,
1397        Tensor.as_strided_: lambda self, size, stride: -1,
1398        Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1,
1399        Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1,
1400        Tensor.bool: lambda self, memory_format=torch.preserve_format: -1,
1401        Tensor.byte: lambda self, memory_format=torch.preserve_format: -1,
1402        Tensor.char: lambda self, memory_format=torch.preserve_format: -1,
1403        Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1,
1404        Tensor.coalesce: lambda self: -1,
1405        Tensor._coalesced_: lambda self, coalesced: -1,
1406        Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1,
1407        Tensor.copy_: lambda self, src, non_blocking=False: -1,
1408        Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1,
1409        Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1,
1410        Tensor.mtia: lambda self, memory_format=torch.preserve_format: -1,
1411        Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1,
1412        Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1,
1413        Tensor.data_ptr: lambda self: -1,
1414        Tensor.dense_dim: lambda self: -1,
1415        Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,
1416        Tensor.dim: lambda self: -1,
1417        Tensor.dim_order: lambda self: -1,
1418        Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
1419        Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,
1420        Tensor.element_size: lambda self: -1,
1421        Tensor.expand: lambda self, size: -1,
1422        Tensor.expand_as: lambda self, other: -1,
1423        Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1,
1424        Tensor.fill_: lambda self, value: -1,
1425        Tensor.fill_diagonal_: lambda self, value: -1,
1426        Tensor.float: lambda self, memory_format=torch.preserve_format: -1,
1427        Tensor.cfloat: lambda self, memory_format=torch.preserve_format: -1,
1428        Tensor.geometric_: lambda self, p, *, generator=None: -1,
1429        Tensor.get_device: lambda self: -1,
1430        Tensor.half: lambda self, memory_format=torch.preserve_format: -1,
1431        Tensor.chalf: lambda self, memory_format=torch.preserve_format: -1,
1432        Tensor.has_names: lambda self: -1,
1433        Tensor.indices: lambda self: -1,
1434        Tensor.int: lambda self, memory_format=torch.preserve_format: -1,
1435        Tensor.is_coalesced: lambda self: -1,
1436        Tensor.is_contiguous: lambda self: -1,
1437        Tensor.is_inference: lambda self: -1,
1438        Tensor.is_pinned: lambda self: -1,
1439        Tensor.is_set_to: lambda self, tensor: -1,
1440        Tensor.is_shared: lambda self: -1,
1441        Tensor.item: lambda self: -1,
1442        Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1,
1443        Tensor.log_softmax: lambda self, dim: -1,
1444        Tensor.long: lambda self, memory_format=torch.preserve_format: -1,
1445        Tensor.map_: lambda self, tensor, callable: -1,
1446        Tensor.map2_: lambda self, x, y, callable: -1,
1447        Tensor.mm: lambda self, mat2: -1,
1448        Tensor.module_load: lambda self, other, assign=False: -1,
1449        Tensor.narrow_copy: lambda self, dimension, start, length: -1,
1450        Tensor.ndimension: lambda self: -1,
1451        Tensor.nelement: lambda self: -1,
1452        Tensor._nested_tensor_size: lambda self: -1,
1453        Tensor._nested_tensor_storage_offsets: lambda self: -1,
1454        Tensor._nested_tensor_strides: lambda self: -1,
1455        Tensor.normal_: lambda self: -1,
1456        Tensor.numpy: lambda self: -1,
1457        Tensor.permute: lambda self, dim: -1,
1458        Tensor.pin_memory: lambda self: -1,
1459        Tensor.put_: lambda self, indices, tensor, accumulate=False: -1,
1460        Tensor.qscheme: lambda self: -1,
1461        Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1,
1462        Tensor.record_stream: lambda self, stream: -1,
1463        Tensor.refine_names: lambda self, names: -1,
1464        Tensor.register_hook: lambda self, hook: -1,
1465        Tensor.register_post_accumulate_grad_hook: lambda self, hook: -1,
1466        Tensor.rename: lambda self, name: -1,
1467        Tensor.repeat: lambda self, *size: -1,
1468        Tensor.requires_grad_: lambda self, requires_grad=True: -1,
1469        Tensor.reshape_as: lambda self, other: -1,
1470        Tensor.resize: lambda self, *size: -1,
1471        Tensor.resize_: lambda self, size: -1,
1472        Tensor.resize_as: lambda self, other: -1,
1473        Tensor.resize_as_sparse_: lambda self, other: -1,
1474        Tensor.retain_grad: lambda self: -1,
1475        Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
1476        Tensor.select_scatter: lambda self, src, dim, index: -1,
1477        Tensor.share_memory_: lambda self: -1,
1478        Tensor.short: lambda self, memory_format=torch.preserve_format: -1,
1479        Tensor.size: lambda self: -1,
1480        Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1,
1481        Tensor.sparse_dim: lambda self: -1,
1482        Tensor.sparse_mask: lambda self, mask: -1,
1483        Tensor._sparse_mask_projection: lambda self, mask, accumulate_matches=False: -1,
1484        Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
1485        Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
1486        Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
1487        Tensor.storage: lambda self: -1,
1488        Tensor.untyped_storage: lambda self: -1,
1489        Tensor.storage_offset: lambda self: -1,
1490        Tensor.storage_type: lambda self: -1,
1491        Tensor.sum_to_size: lambda self, size: -1,
1492        Tensor.tile: lambda self, *reps: -1,
1493        Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
1494        Tensor.to_dense: lambda self, dtype=None, *, masked_grad=None: -1,
1495        Tensor._to_dense: lambda self, dtype=None, masked_grad=None: -1,
1496        Tensor.to_sparse: lambda self: -1,
1497        Tensor.tolist: lambda self: -1,
1498        Tensor.to_mkldnn: lambda self: -1,
1499        Tensor.type_as: lambda self, other: -1,
1500        Tensor.unfold: lambda self, dimension, size, step: -1,
1501        Tensor.uniform_: lambda self, from_=0, to=1: -1,
1502        Tensor.values: lambda self: -1,
1503        Tensor.view: lambda self, shape: -1,
1504        Tensor.view_as: lambda self, other: -1,
1505        Tensor.zero_: lambda self: -1,
1506        Tensor.__dlpack__: lambda self, stream=None: -1,
1507        Tensor.__dlpack_device__: lambda self: -1,
1508        torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
1509    }  # fmt: skip
1510
1511    privateuse1_backend_name = (
1512        torch.utils.backend_registration._privateuse1_backend_name
1513    )
1514    if hasattr(Tensor, privateuse1_backend_name):
1515        ret[getattr(Tensor, privateuse1_backend_name)] = (
1516            lambda self, device=None, non_blocking=False, **kwargs: -1
1517        )
1518        ret[getattr(Tensor, f"is_{privateuse1_backend_name}").__get__] = lambda self: -1
1519
1520    ret2 = {}
1521    ignored = get_ignored_functions()
1522
1523    for k, v in ret.items():
1524        # Generate methods like __add__ and add_ by default from add
1525        names = [
1526            k.__name__,  # Default method
1527            k.__name__ + "_",  # Inplace variant
1528            "__" + k.__name__ + "__",  # Dunder method
1529            "__i" + k.__name__ + "__",  # Inplace dunder method
1530            "__r" + k.__name__ + "__",  # Reverse dunder method
1531        ]
1532
1533        if k.__name__.startswith("bitwise_"):
1534            # bitwise_<op> have dunder methods of the form __<op>__
1535            # And so on.
1536            subname = k.__name__[len("bitwise_") :]
1537            names.extend(
1538                ["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"]
1539            )
1540
1541        for name in names:
1542            func = getattr(Tensor, name, None)
1543            if callable(func) and func not in ret and func not in ignored:
1544                ret2[func] = v
1545
1546    ret.update(ret2)
1547    return ret
1548
1549
def wrap_torch_function(dispatcher: Callable):
    """Decorator factory that makes a function honour ``__torch_function__``.

    Parameters
    ----------
    dispatcher: Callable
        Called with exactly the arguments given to the decorated function. It
        must return an iterable of the Tensor-likes among those arguments.

    Returns
    -------
    Callable
        A decorator. The function it produces first asks ``dispatcher`` for
        the relevant arguments; if any of them overrides
        ``__torch_function__`` the call is routed through
        ``handle_torch_function``, otherwise the decorated function runs
        unchanged.

    Note
    ----
    This decorator may reduce the performance of your code. Generally, it's
    enough to express your code as a series of functions that, themselves,
    support __torch_function__. Reach for this decorator only in the rare case
    where that is not possible, e.g. when wrapping a low-level library that
    also has to work for Tensor-likes.

    Examples
    --------
    >>> def dispatcher(a):  # Must have the same signature as func
    ...     return (a,)
    >>> @torch.overrides.wrap_torch_function(dispatcher)
    ... def func(a):  # This will make func dispatchable by __torch_function__
    ...     return a + 0
    """

    def inner(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            tensor_likes = dispatcher(*args, **kwargs)
            if not has_torch_function(tensor_likes):
                # Nothing overrides __torch_function__: run the plain function.
                return func(*args, **kwargs)

            # Pass the wrapper itself (not ``func``) as the public API so that
            # overrides see the callable the user actually invoked.
            return handle_torch_function(wrapped, tensor_likes, *args, **kwargs)

        return wrapped

    return inner
1586
1587
1588def _get_overloaded_args(
1589    relevant_args: Iterable[Any],
1590    get_type_fn: Callable[[Any], Type] = None,
1591) -> List[Any]:
1592    """Returns a list of arguments on which to call __torch_function__.
1593
1594    Checks arguments in relevant_args for __torch_function__ implementations,
1595    storing references to the arguments and their types in overloaded_args and
1596    overloaded_types in order of calling precedence. Only distinct types are
1597    considered. If a type is a subclass of another type it will have higher
1598    precedence, otherwise the precedence order is the same as the order of
1599    arguments in relevant_args, that is, from left-to-right in the argument list.
1600
1601    The precedence-determining algorithm implemented in this function is
1602    described in `NEP-0018`_.
1603
1604    See torch::append_overloaded_arg for the equivalent function in the C++
1605    implementation.
1606
1607    Parameters
1608    ----------
1609    relevant_args : iterable of array-like
1610        Iterable of array-like arguments to check for __torch_function__
1611        methods.
1612
1613    get_type_fn : callable, optional
1614        Function to call on each argument in relevant_args to get its type.
1615
1616    Returns
1617    -------
1618    overloaded_args : list
1619        Arguments from relevant_args on which to call __torch_function__
1620        methods, in the order in which they should be called.
1621
1622    .. _NEP-0018:
1623       https://numpy.org/neps/nep-0018-array-function-protocol.html
1624    """
1625    if get_type_fn is None:
1626        get_type_fn = type
1627
1628    # If torch function is not enabled, there are no overloaded types
1629    if not torch._C._is_torch_function_enabled():
1630        return []
1631    # Runtime is O(num_arguments * num_unique_types)
1632    overloaded_types: Set[Type] = set()
1633    overloaded_args: List[Any] = []
1634    for arg in relevant_args:
1635        arg_type = get_type_fn(arg)
1636        # We only collect arguments if they have a unique type, which ensures
1637        # reasonable performance even with a long list of possibly overloaded
1638        # arguments.
1639        #
1640        # NB: Important to exclude _disabled_torch_function_impl, otherwise
1641        # https://github.com/pytorch/pytorch/issues/64687
1642        if (
1643            arg_type not in overloaded_types
1644            and hasattr(arg_type, "__torch_function__")
1645            and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl
1646        ):
1647            # Create lists explicitly for the first type (usually the only one
1648            # done) to avoid setting up the iterator for overloaded_args.
1649            if overloaded_types:
1650                overloaded_types.add(arg_type)
1651                # By default, insert argument at the end, but if it is
1652                # subclass of another argument, insert it before that argument.
1653                # This ensures "subclasses before superclasses".
1654                index = len(overloaded_args)
1655                for i, old_arg in enumerate(overloaded_args):
1656                    if issubclass(arg_type, get_type_fn(old_arg)):
1657                        index = i
1658                        break
1659                overloaded_args.insert(index, arg)
1660            else:
1661                overloaded_types = {arg_type}
1662                overloaded_args = [arg]
1663    return overloaded_args
1664
1665
def handle_torch_function(
    public_api: Callable,
    relevant_args: Iterable[Any],
    *args,
    **kwargs,
) -> Any:
    """Implement a function with checks for ``__torch_function__`` overrides.

    See torch::autograd::handle_torch_function for the equivalent of this
    function in the C++ implementation.

    Arguments
    ---------
    public_api : function
        Function exposed by the public torch API originally called like
        ``public_api(*args, **kwargs)`` on which arguments are now being
        checked.
    relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.
    args : tuple
        Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : dict
        Arbitrary keyword arguments originally passed into ``public_api``.

    Returns
    -------
    object
        The first result that is not ``NotImplemented``, taken from the
        active ``__torch_function__`` mode (if one is enabled) and then from
        the ``__torch_function__`` of each overloaded argument in turn.

    Raises
    ------
    TypeError : if no implementation is found, i.e. the mode and every
        override returned ``NotImplemented``.

    Example
    -------
    >>> def func(a):
    ...     if has_torch_function_unary(a):
    ...         return handle_torch_function(func, (a,), a)
    ...     return a + 0
    """
    # Check for __torch_function__ methods.
    overloaded_args = _get_overloaded_args(relevant_args)
    # overloaded_args already have unique types.
    overloaded_types = tuple(map(type, overloaded_args))

    # Check for __torch_function__ mode.
    if _is_torch_function_mode_enabled():
        # The mode is taken off the stack while its handler runs and pushed
        # back afterwards (see _pop_mode_temporarily).
        with _pop_mode_temporarily() as mode:
            result = mode.__torch_function__(
                public_api, overloaded_types, args, kwargs
            )
        if result is not NotImplemented:
            return result

    # Call overrides
    for overloaded_arg in overloaded_args:
        # This call needs to become a classmethod call in the future.
        # See https://github.com/pytorch/pytorch/issues/63767
        torch_func_method = overloaded_arg.__torch_function__
        # A bound method whose __self__ is the instance (rather than its
        # class) means __torch_function__ was written as a plain method.
        if (
            hasattr(torch_func_method, "__self__")
            and torch_func_method.__self__ is overloaded_arg
            and torch_func_method is not torch._C._disabled_torch_function_impl
        ):
            warnings.warn(
                "Defining your `__torch_function__ as a plain method is deprecated and "
                "will be an error in future, please define it as a classmethod.",
                DeprecationWarning,
            )

        # Pass `public_api` itself so __torch_function__ implementations can
        # do equality/identity comparisons against the public function.
        result = torch_func_method(public_api, overloaded_types, args, kwargs)

        if result is not NotImplemented:
            return result

    try:
        func_name = f"{public_api.__module__}.{public_api.__name__}"
    except AttributeError:
        # Not every callable carries both attributes (functools.partial
        # objects have no __name__, for instance). Fall back to repr() so the
        # documented TypeError is what the caller sees, not an AttributeError.
        func_name = repr(public_api)
    msg = (
        f"no implementation found for '{func_name}' on types that implement "
        f"__torch_function__: {[type(arg) for arg in overloaded_args]}"
    )
    if _is_torch_function_mode_enabled():
        msg += f" nor in mode {_get_current_function_mode()}"
    raise TypeError(msg)
1752
1753
# The three checks below are imported from torch._C; this module only attaches
# their docstrings and re-exposes them under public names.
has_torch_function = _add_docstr(
    _has_torch_function,
    r"""Check for __torch_function__ implementations in the elements of an iterable
    or if a __torch_function__ mode is enabled.  Considers exact ``Tensor`` s
    and ``Parameter`` s non-dispatchable.  Use this to guard a call to
    :func:`handle_torch_function`; don't use it to test if something
    is Tensor-like, use :func:`is_tensor_like` instead.

    Arguments
    ---------
    relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.

    Returns
    -------
    bool
        True if any of the elements of relevant_args have __torch_function__
        implementations, False otherwise.

    See Also
    --------
    torch.overrides.is_tensor_like
        Checks if something is a Tensor-like, including an exact ``Tensor``.
    """,
)

has_torch_function_unary = _add_docstr(
    _has_torch_function_unary,
    r"""Special case of `has_torch_function` for single inputs.
    Instead of:
      `has_torch_function((t,))`
    call:
      `has_torch_function_unary(t)`
    which skips unnecessary packing and unpacking work.
    """,
)

has_torch_function_variadic = _add_docstr(
    _has_torch_function_variadic,
    r"""Special case of `has_torch_function` that skips tuple creation.

    This uses the METH_FASTCALL protocol introduced in Python 3.7

    Instead of:
      `has_torch_function((a, b))`
    call:
      `has_torch_function_variadic(a, b)`
    which skips unnecessary packing and unpacking work.
    """,
)
1801
1802
@functools.lru_cache(None)
def _get_overridable_functions() -> (
    Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
):
    """Scan the tested namespaces for overridable functions.

    The result is cached, so the scan runs at most once per process.

    Returns
    -------
    Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
        ``(overridable_funcs, index)``.

        ``overridable_funcs`` maps each namespace object to the callables in
        it that are neither skipped by name nor listed by
        ``get_ignored_functions()``.  Non-callable descriptors (properties)
        are recorded differently: their ``__get__`` is stored under the
        descriptor object itself rather than under the namespace.

        ``index`` maps every callable that was looked at (including ones
        ignored by name) to its dotted name, e.g. ``"torch.add"``; for
        descriptors both ``__get__`` and ``__set__`` are indexed.
    """
    overridable_funcs = collections.defaultdict(list)
    index = {}
    # Fetched once up front instead of once per candidate function.
    ignored_functions = get_ignored_functions()
    ignored_msg = (
        "{} is in the tuple returned by torch.overrides.get_ignored_functions "
        "but still has an explicit override"
    )
    tested_namespaces = [
        ("torch", torch, torch.__all__),
        ("torch.functional", torch.functional, torch.functional.__all__),
        ("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
        ("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
        ("torch.Tensor", torch.Tensor, dir(torch.Tensor)),
        ("torch.linalg", torch.linalg, dir(torch.linalg)),
        ("torch.fft", torch.fft, dir(torch.fft)),
        ("torch.special", torch.special, dir(torch.special)),
    ]
    for namespace_str, namespace, ns_funcs in tested_namespaces:
        is_tensor_namespace = namespace is torch.Tensor
        for func_name in ns_funcs:
            qualified_name = f"{namespace_str}.{func_name}"
            ignore = False
            # ignore private functions or functions that are deleted in torch.__init__
            if not is_tensor_namespace:
                if func_name.startswith("__"):
                    continue
                elif (
                    func_name.startswith("_")
                    or func_name.endswith("_")
                    or not func_name[0].islower()
                ):
                    # Still indexed by name below, but never reported as
                    # overridable.
                    ignore = True
                elif func_name == "unique_dim":
                    continue
            elif func_name == "__weakref__":
                continue

            func = getattr(namespace, func_name)
            # Skip torch.Tensor attributes that compare equal to the
            # attribute of the same name on ``object``.
            if is_tensor_namespace and getattr(object, func_name, None) == func:
                continue
            # ignore re-exported modules
            if isinstance(func, types.ModuleType):
                continue
            # ignore __future__ imports
            if isinstance(func, __future__._Feature):
                continue

            # Non-callable descriptors (properties): index and report their
            # accessor methods instead of the descriptor itself.
            if not callable(func) and hasattr(func, "__get__"):
                getter = func.__get__
                index[getter] = f"{qualified_name}.__get__"
                index[func.__set__] = f"{qualified_name}.__set__"
                if ignore:
                    continue
                if getter in ignored_functions:
                    assert getter not in get_testing_overrides(), ignored_msg.format(
                        qualified_name
                    )
                else:
                    # Keyed by the descriptor, not the namespace; callers
                    # that iterate over all values still see the getter.
                    overridable_funcs[func].append(getter)
                continue

            if not callable(func):
                continue

            index[func] = qualified_name

            if ignore:
                continue

            # cannot be overridden by __torch_function__
            if func in ignored_functions:
                assert func not in get_testing_overrides(), ignored_msg.format(
                    qualified_name
                )
                continue
            overridable_funcs[namespace].append(func)
    return overridable_funcs, index
1888
1889
1890@_disable_user_warnings
1891def get_overridable_functions() -> Dict[Any, List[Callable]]:
1892    """List functions that are overridable via __torch_function__
1893
1894    Returns
1895    -------
1896    Dict[Any, List[Callable]]
1897        A dictionary that maps namespaces that contain overridable functions
1898        to functions in that namespace that can be overridden.
1899    """
1900    return _get_overridable_functions()[0]
1901
1902
1903@_disable_user_warnings
1904def resolve_name(f):
1905    """Get a human readable string name for a function passed to
1906    __torch_function__
1907
1908    Arguments
1909    ---------
1910    f : Callable
1911        Function to resolve the name of.
1912
1913    Returns
1914    -------
1915    str
1916        Name of the function; if eval'ed it should give back the input
1917        function.
1918    """
1919    if isinstance(f, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
1920        return str(f)
1921    return _get_overridable_functions()[1].get(f)
1922
1923
@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
    """Return the overridable functions recorded for ``torch.Tensor`` as a set.

    The result is cached after the first call.
    """
    return set(get_overridable_functions()[torch.Tensor])
1930
1931
1932@_disable_user_warnings
1933def is_tensor_method_or_property(func: Callable) -> bool:
1934    """
1935    Returns True if the function passed in is a handler for a
1936    method or property belonging to ``torch.Tensor``, as passed
1937    into ``__torch_function__``.
1938
1939    .. note::
1940       For properties, their ``__get__`` method must be passed in.
1941
1942    This may be needed, in particular, for the following reasons:
1943
1944    1. Methods/properties sometimes don't contain a `__module__` slot.
1945    2. They require that the first passed-in argument is an instance
1946       of ``torch.Tensor``.
1947
1948    Examples
1949    --------
1950    >>> is_tensor_method_or_property(torch.Tensor.add)
1951    True
1952    >>> is_tensor_method_or_property(torch.add)
1953    False
1954    """
1955    return func in _get_tensor_methods() or func.__name__ == "__get__"
1956
1957
def is_tensor_like(inp):
    """
    Returns ``True`` if the passed-in input is a Tensor-like.

    An input counts as Tensor-like when it is an exact ``torch.Tensor`` or
    when it exposes a ``__torch_function__`` attribute.

    Examples
    --------
    A subclass of tensor is generally a Tensor-like.

    >>> class SubTensor(torch.Tensor): ...
    >>> is_tensor_like(SubTensor([0]))
    True

    Built-in or user types aren't usually Tensor-like.

    >>> is_tensor_like(6)
    False
    >>> is_tensor_like(None)
    False
    >>> class NotATensor: ...
    >>> is_tensor_like(NotATensor())
    False

    But, they can be made Tensor-like by implementing __torch_function__.

    >>> class TensorLike:
    ...     @classmethod
    ...     def __torch_function__(cls, func, types, args, kwargs):
    ...         return -1
    >>> is_tensor_like(TensorLike())
    True
    """
    # Exact-type check first; subclasses fall through to the attribute test.
    if type(inp) is torch.Tensor:
        return True
    return hasattr(inp, "__torch_function__")
1993
1994
class TorchFunctionMode:
    """
    A ``TorchFunctionMode`` overrides the meaning of every
    ``__torch_function__`` overrideable function inside a dynamic scope,
    with no need to create a tensor subclass or to monkey-patch functions
    in the PyTorch API.  Typical reasons to reach for a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchFunctionMode` are compositional:
    modes are pushed onto a stack with ``with MyMode():`` and popped again
    when the ``with`` block exits.  PyTorch API calls made from inside your
    ``__torch_function__`` implementation are, by default, forwarded to the
    next mode on the mode stack.  To re-enter your own implementation,
    invoke ``self.__torch_function__(...)`` explicitly (beware of infinite
    loops in that case!).
    """

    # NOTE(review): earlier documentation also suggested
    # ``enable_torch_function_mode(self, replace=self.inner)`` for
    # self-referential dispatch; that helper is not visible in this part of
    # the module -- confirm it still exists before recommending it.
    inner: "TorchFunctionMode"

    # Force metaclass to generate constructor at the base of the hierarchy
    # (NOTE(review): no metaclass is declared on this class -- confirm
    # whether this comment is still accurate).
    def __init__(self) -> None:
        """Create the mode; the base class stores no state."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Handle an intercepted call; subclasses must override this."""
        raise NotImplementedError()

    def __enter__(self):
        """Push this mode onto the mode stack and return it."""
        _push_mode(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Pop the top of the mode stack; exceptions are not suppressed."""
        _pop_mode()

    @classmethod
    def push(cls, *args, **kwargs):
        """Warn that ``push`` is unnecessary and return ``cls(*args, **kwargs)``.

        The new instance is only constructed here; it is not pushed onto the
        mode stack.
        """
        warnings.warn(
            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
        )
        return cls(*args, **kwargs)
2049
2050
def _get_current_function_mode():
    """Return the last entry of the torch-function mode stack, or ``None``
    when the stack is empty."""
    depth = _len_torch_function_stack()
    if depth <= 0:
        return None
    return _get_function_stack_at(depth - 1)
2054
2055
def _get_current_function_mode_stack():
    """Return every entry of the torch-function mode stack as a list, in
    index order (index 0 first)."""
    depth = _len_torch_function_stack()
    return list(map(_get_function_stack_at, range(depth)))
2059
2060
def _push_mode(mode):
    """Push ``mode`` onto the torch-function mode stack.

    Thin wrapper that forwards to ``_push_on_torch_function_stack``.
    """
    _push_on_torch_function_stack(mode)
2063
2064
def _pop_mode():
    """Pop an entry off the torch-function mode stack and return it."""
    return _pop_torch_function_stack()
2068
2069
@contextlib.contextmanager
def _pop_mode_temporarily():
    """Context manager that pops a mode off the stack for the duration of
    the ``with`` body, yields it, and pushes it back on exit."""
    popped_mode = _pop_mode()
    try:
        yield popped_mode
    finally:
        # Runs on normal exit and when the body raises.
        _push_mode(popped_mode)
2077
2078
class BaseTorchFunctionMode(TorchFunctionMode):
    """Mode whose handler calls ``func(*args, **kwargs)`` unchanged."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # ``None`` stands for "no keyword arguments".
        call_kwargs = {} if kwargs is None else kwargs
        return func(*args, **call_kwargs)
2084
2085
@contextlib.contextmanager
def enable_reentrant_dispatch():
    """Context manager that runs its body inside
    ``torch._C._RestorePythonTLSSnapshot()``.

    Yields ``None``.
    """
    # NB: this can't simply be
    # `enable_reentrant_dispatch = torch._C._RestorePythonTLSSnapshot`
    # because:
    # 1. torch._C._RestorePythonTLSSnapshot is unavailable when this file
    #    initially gets imported. Probably an import order thing.
    # 2. enable_reentrant_dispatch is technically public API; assigning
    #    it the object would change the __module__ to look private.
    with torch._C._RestorePythonTLSSnapshot():
        yield
2100