# Owner(s): ["module: __torch_function__"]

import torch
import numpy as np
import inspect
import functools
import pprint
import pickle
import collections
import unittest
import contextlib

from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO
from torch.overrides import (
    handle_torch_function,
    has_torch_function,
    get_ignored_functions,
    get_overridable_functions,
    get_testing_overrides,
    resolve_name,
    is_tensor_method_or_property,
    TorchFunctionMode,
    _get_current_function_mode,
    _get_current_function_mode_stack,
    BaseTorchFunctionMode
)
from torch.utils._mode_utils import all_same_mode
from torch.utils._pytree import tree_map

Tensor = torch.Tensor

# The functions below simulate the pure-python torch functions in the
# torch.functional namespace. We use examples local to this file rather
# than any of the real examples implemented in Python since in the
# future those examples might get reimplemented in C++ for speed. These
# fake torch functions let us verify that the dispatch rules work the
# same for torch functions implemented in C++ or Python.
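#
# Each fake function follows the same dispatch idiom: check
# has_torch_function over the arguments that may implement
# __torch_function__, and, if any does, delegate via
# handle_torch_function(public_api, relevant_args, *args, **kwargs)
# instead of running the default implementation.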

def foo(a, b, c=None):
    """A function with multiple arguments and an optional argument"""
    if has_torch_function((a, b, c)):
        return handle_torch_function(foo, (a, b, c), a, b, c=c)
    if c is not None:
        return a + b + c
    return a + b

def bar(a):
    """A function with one argument"""
    if has_torch_function((a,)):
        return handle_torch_function(bar, (a,), a)
    return a

def baz(a, b):
    """A function with multiple arguments"""
    if has_torch_function((a, b)):
        return handle_torch_function(baz, (a, b), a, b)
    return a + b

def quux(a):
    """Used to test that errors raised in user implementations get propagated"""
    if has_torch_function((a,)):
        return handle_torch_function(quux, (a,), a)
    return a

# HANDLED_FUNCTIONS_DIAGONAL is a dispatch table that
# DiagonalTensor.__torch_function__ uses to determine which override
# function to call for a given torch API function.  The keys of the
# dictionary are functions in the torch API and the values are the
# override implementations. Implementations are added to
# HANDLED_FUNCTIONS_DIAGONAL by decorating a python function with
# implements_diagonal. See the overrides immediately below the definition
# of DiagonalTensor for usage examples.
HANDLED_FUNCTIONS_DIAGONAL = {}
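# For example, after the registrations below, HANDLED_FUNCTIONS_DIAGONAL maps
# torch.mean to the mean() override defined under DiagonalTensor, so
# torch.mean(DiagonalTensor(5, 2)) dispatches there and returns 2 / 5 == 0.4.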

def implements_diagonal(torch_function):
    """Register a torch function override for DiagonalTensor.

    This decorator takes a function in the torch API as a
    parameter. Applying this decorator to a function adds that function
    as the registered override for the torch function passed as a
    parameter to the decorator. See DiagonalTensor.__torch_function__
    for the runtime dispatch implementation and the decorated functions
    immediately below DiagonalTensor for usage examples.
    """
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func
        return func
    return decorator

class DiagonalTensor:
    """A class with __torch_function__ and a specific diagonal representation

    This class has limited utility and is mostly useful for verifying that the
    dispatch mechanism works as expected. It is based on the `DiagonalArray
    example`_ in the NumPy documentation.

    Note that this class does *not* inherit from ``torch.Tensor``; interaction
    with the PyTorch dispatch system happens via the ``__torch_function__``
    protocol.

    ``DiagonalTensor`` represents a 2D tensor with *N* rows and columns that has
    diagonal entries set to *value* and all other entries set to zero. The
    main functionality of ``DiagonalTensor`` is to provide a more compact
    string representation of a diagonal tensor than in the base tensor class:

    >>> d = DiagonalTensor(5, 2)
    >>> d
    DiagonalTensor(N=5, value=2)
    >>> d.tensor()
    tensor([[2., 0., 0., 0., 0.],
            [0., 2., 0., 0., 0.],
            [0., 0., 2., 0., 0.],
            [0., 0., 0., 2., 0.],
            [0., 0., 0., 0., 2.]])

    Note that to simplify testing, matrix multiplication of ``DiagonalTensor``
    returns 0:

    >>> torch.mm(d, d)
    0

    .. _DiagonalArray example:
        https://numpy.org/devdocs/user/basics.dispatch.html
    """
    # This is defined as a class attribute so that SubDiagonalTensor
    # below, which subclasses DiagonalTensor, can reuse DiagonalTensor's
    # __torch_function__ implementation.
    handled_functions = HANDLED_FUNCTIONS_DIAGONAL

    def __init__(self, N, value):
        self._N = N
        self._i = value

    def __repr__(self):
        return f"DiagonalTensor(N={self._N}, value={self._i})"

    def __array__(self):
        return self._i * np.eye(self._N)

    def tensor(self):
        return self._i * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
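        # Returning NotImplemented lets the dispatcher fall through to the
        # next __torch_function__ implementation in precedence order (a
        # TypeError is raised if every implementation declines).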
        if func not in cls.handled_functions:
            return NotImplemented
        return cls.handled_functions[func](*args, **kwargs)

    def __eq__(self, other):
        return type(other) is type(self) and self._N == other._N and self._i == other._i

@implements_diagonal(torch.mean)
def mean(mat):
    return float(mat._i) / mat._N

@implements_diagonal(torch.mm)
def diagonal_mm(mat1, mat2):
    return 0

@implements_diagonal(torch.div)
def diagonal_div(input, other, out=None):
    return -1

@implements_diagonal(torch.add)
def add(mat1, mat2):
    raise ValueError

@implements_diagonal(foo)
def diagonal_foo(a, b, c=None):
    return -1

@implements_diagonal(bar)
def diagonal_bar(a):
    return -1

@implements_diagonal(quux)
def diagonal_quux(a):
    raise ValueError

# The dispatch table for SubTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_SUB = {}

def implements_sub(torch_function):
    "Register a torch function override for SubTensor"
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_SUB[torch_function] = func
        return func
    return decorator

class SubTensor(torch.Tensor):
    """A subclass of torch.Tensor used for testing __torch_function__ dispatch

    This class has the property that matrix multiplication returns zero:

    >>> s = SubTensor([[1, 1], [1, 1]])
    >>> torch.mm(s, s)
    0
    >>> t = torch.tensor([[1, 1], [1, 1]])
    >>> torch.mm(s, t)
    0
    >>> torch.mm(t, s)
    0
    >>> torch.mm(t, t)
    tensor([[2, 2],
            [2, 2]])

    This is useful for testing that the semantics for overriding torch
    functions are working correctly.
    """
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func not in HANDLED_FUNCTIONS_SUB:
            return NotImplemented
        return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs)

class SubTensor2(torch.Tensor):
    pass

class SubSubTensor2(SubTensor2):
    pass

class SubTensor3(torch.Tensor):
    pass

@implements_sub(torch.mean)
def sub_mean(mat):
    return 0

@implements_sub(torch.mm)
def sub_mm(mat1, mat2):
    return -1

@implements_sub(bar)
def sub_bar(mat):
    return 1

@implements_sub(torch.div)
def sub_div(input, other, out=None):
    return NotImplemented

# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_SUB_DIAGONAL = {}

def implements_sub_diagonal(torch_function):
    "Register a torch function override for SubDiagonalTensor"
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func
        return func
    return decorator

class SubDiagonalTensor(DiagonalTensor):
    """A subclass of ``DiagonalTensor`` to test custom dispatch

    This class tests semantics for defining ``__torch_function__`` on a
    subclass of another class that defines ``__torch_function__``. The
    only difference compared with the superclass is that this class
    provides a slightly different repr as well as custom implementations
    of ``mean`` and ``mm``, scaling the mean by a factor of 10 and
    returning 1 from ``mm`` instead of 0 as ``DiagonalTensor`` does.
    """
    handled_functions = HANDLED_FUNCTIONS_SUB_DIAGONAL

    def __repr__(self):
        return f"SubDiagonalTensor(N={self._N}, value={self._i})"


@implements_sub_diagonal(torch.mean)
def sub_diagonal_mean(mat):
    return 10 * float(mat._i) / mat._N

@implements_sub_diagonal(bar)
def sub_diagonal_bar(mat):
    return 0

@implements_sub_diagonal(torch.mm)
def sub_diagonal_mm(mat1, mat2):
    return 1

@implements_sub_diagonal(torch.div)
def sub_diagonal_div(input, other, out=None):
    return NotImplemented

@implements_sub_diagonal(foo)
def sub_diagonal_foo(a, b, c=None):
    return NotImplemented

# The dispatch table for TensorLike's __torch_function__ implementation.
HANDLED_FUNCTIONS_TENSOR_LIKE = {}


# Note: _triggered wrapper
# Dict that wraps the implementations from get_testing_overrides into another
# function with a _triggered slot/flag. The triggered flag is set when the
# implementation is called.
WRAPPED_TRIGGERED_IMPLS = {}


def triggered_wrapper(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        wrapped._triggered = True
        return f(*args, **kwargs)

    wrapped._triggered = False
    return wrapped
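
# A minimal sketch of the flag in action (using triggered_wrapper as defined
# above):
#
#     f = triggered_wrapper(lambda: -1)
#     f._triggered   # False
#     f()            # returns -1 and flips the flag
#     f._triggered   # True
#
# The generated tests below rely on this to assert that an override ran.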

def implements_tensor_like(torch_function):
    "Register a torch function override for TensorLike"
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func
        return func
    return decorator

def generate_tensor_like_torch_implementations():
    torch_vars = vars(torch)
    untested_funcs = []
    testing_overrides = get_testing_overrides()
    # test/test_cpp_api_parity.py monkeypatches torch.nn to have a new
    # function sample_functional.  Depending on what order you run pytest
    # collection, this may trigger the error here.  This is a hack to fix
    # the problem.  A more proper fix is to make the "not tested" check
    # a test on its own, and to make sure the monkeypatch is only installed
    # for the span of the relevant test (and deleted afterwards)
    testing_ignore = {"sample_functional", "autocast"}
    for namespace, funcs in get_overridable_functions().items():
        for func in funcs:
            if func not in testing_overrides and func.__name__ not in testing_ignore:
                untested_funcs.append(f"{namespace}.{func.__name__}")
    msg = (
        "The following functions are not tested for __torch_function__ "
        "support, please ensure there is an entry in the dict returned by "
        "torch.overrides.get_testing_overrides for this function or if a "
        "__torch_function__ override does not make sense, add an entry to "
        "the tuple returned by torch.overrides.get_ignored_functions.\n\n{}"
    )
    assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs))
    for func, override in testing_overrides.items():
        # Decorate the override with implements_sub if func is a
        # torch.Tensor method or property; otherwise use
        # implements_tensor_like.
        wrapped = triggered_wrapper(override)
        # See note: "_triggered wrapper"
        WRAPPED_TRIGGERED_IMPLS[func] = wrapped
        if is_tensor_method_or_property(func):
            implements_sub(func)(wrapped)
        else:
            implements_tensor_like(func)(wrapped)

generate_tensor_like_torch_implementations()

class TensorLike:
    """A class that overrides the full torch API

    This class is used to explicitly test that the full torch.Tensor API
    can be overridden with a class that defines __torch_function__.
    """
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func not in HANDLED_FUNCTIONS_TENSOR_LIKE:
            return NotImplemented
        # In this case __torch_function__ should override TensorLike objects
        return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs)
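
# An illustration grounded in the generated tests below: passing TensorLike
# where a Tensor is expected dispatches to the wrapped testing override,
# which returns the sentinel -1; e.g. torch.mm(TensorLike(), TensorLike())
# evaluates to -1.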

class TestTorchFunctionOverride(TestCase):
    @classmethod
    def setUpClass(cls):
        cls._stack = contextlib.ExitStack()
        if TEST_WITH_TORCHDYNAMO:
            # Add classes to the wrapped tensor subclasses
            @contextlib.contextmanager
            def setup_subclasses():
                old = set(torch._dynamo.config.traceable_tensor_subclasses)
                torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor)
                try:
                    yield
                finally:
                    torch._dynamo.config.traceable_tensor_subclasses.clear()
                    torch._dynamo.config.traceable_tensor_subclasses.update(old)

            cls._stack.enter_context(setup_subclasses())

    @classmethod
    def tearDownClass(cls):
        cls._stack.close()

    def test_mean_semantics(self):
        """Test that a function with one argument can be overridden"""
        t1 = DiagonalTensor(5, 2)
        t2 = SubTensor([[1, 2], [1, 2]])
        t3 = SubDiagonalTensor(5, 2)
        self.assertEqual(torch.mean(t1), 0.4)
        self.assertEqual(bar(t1), -1)
        self.assertEqual(torch.mean(t2), 0)
        self.assertEqual(bar(t2), 1)
        self.assertEqual(torch.mean(t3), 4.0)
        self.assertEqual(bar(t3), 0)

    def test_has_torch_function_non_sequence(self):
        with self.assertRaisesRegex(TypeError, "expected a sequence"):
            has_torch_function(object())

    def test_mm_semantics(self):
        """Test that a function with multiple arguments can be overridden"""
        t1 = DiagonalTensor(5, 2)
        t2 = torch.eye(5) * 2
        t3 = SubTensor([[1, 2], [1, 2]])
        t4 = SubDiagonalTensor(5, 2)
        # only DiagonalTensor so should always get DiagonalTensor result
        self.assertEqual(torch.mm(t1, t1), 0)
        # tensor and DiagonalTensor, always return DiagonalTensor result
        self.assertEqual(torch.mm(t1, t2), 0)
        self.assertEqual(torch.mm(t2, t1), 0)
        # only SubTensor so should always get SubTensor result
        self.assertEqual(torch.mm(t3, t3), -1)
        # tensor and SubTensor so should always get SubTensor result
        self.assertEqual(torch.mm(t3, t2), -1)
        self.assertEqual(torch.mm(t2, t3), -1)
        # DiagonalTensor and SubTensor are unrelated classes so the result
        # depends on which argument appears first
        self.assertEqual(torch.mm(t3, t1), -1)
        self.assertEqual(torch.mm(t1, t3), 0)
        # SubDiagonalTensor should take precedence over DiagonalTensor
        # but should behave otherwise the same as DiagonalTensor
        self.assertEqual(torch.mm(t4, t4), 1)
        self.assertEqual(torch.mm(t4, t1), 1)
        self.assertEqual(torch.mm(t1, t4), 1)
        self.assertEqual(torch.mm(t4, t2), 1)
        self.assertEqual(torch.mm(t2, t4), 1)
        self.assertEqual(torch.mm(t3, t4), -1)
        self.assertEqual(torch.mm(t4, t3), 1)

    def test_precedence_semantics(self):
        """Test semantics for __torch_function__ for functions that take
        multiple arguments

        For functions that take multiple arguments, the appropriate
        __torch_function__ implementation to call is determined by
        examining the types of the arguments. The precedence order is
        left-to-right in the argument list, except subclasses are always
        checked before superclasses. The first result of calling the
        implementations in precedence order that is not NotImplemented
        is returned to the user. If all implementations return
        NotImplemented, a TypeError is raised.

        All cases are tested with functions implemented in C++ and
        either foo or baz, which are python functions defined above that
        are instrumented to obey the same dispatch rules as the
        functions in torch.functional.
        """
        # DiagonalTensor has a valid override and SubDiagonalTensor has an
        # override that returns NotImplemented so we should call the
        # DiagonalTensor implementation, returning -1
        t1 = DiagonalTensor(5, 2)
        t2 = SubDiagonalTensor(5, 2)
        self.assertEqual(torch.div(t1, t2), -1)
        self.assertEqual(torch.div(t2, t1), -1)
        self.assertEqual(foo(t1, t2), -1)
        self.assertEqual(foo(t2, t1), -1)

        # SubTensor has an implementation that returns NotImplemented as
        # well so it should behave exactly like SubDiagonalTensor in the
        # test above
        t3 = SubTensor([[1, 2], [1, 2]])
        self.assertEqual(torch.div(t1, t3), -1)
        self.assertEqual(torch.div(t3, t1), -1)
        self.assertEqual(foo(t1, t3), -1)
        self.assertEqual(foo(t3, t1), -1)

        # div between SubTensor and SubDiagonalTensor should raise
        # TypeError since both have an implementation that
        # explicitly returns NotImplemented
        with self.assertRaises(TypeError):
            torch.div(t2, t3)
        with self.assertRaises(TypeError):
            torch.div(t3, t2)
        with self.assertRaises(TypeError):
            foo(t2, t3)
        with self.assertRaises(TypeError):
            foo(t3, t2)

        # none of DiagonalTensor, SubDiagonalTensor, or SubTensor have a
        # mul or a baz implementation so all ops should raise TypeError
        with self.assertRaises(TypeError):
            torch.mul(t1, t1)
        with self.assertRaises(TypeError):
            torch.mul(t1, t2)
        with self.assertRaises(TypeError):
            torch.mul(t1, t3)
        with self.assertRaises(TypeError):
            torch.mul(t2, t1)
        with self.assertRaises(TypeError):
            torch.mul(t2, t2)
        with self.assertRaises(TypeError):
            torch.mul(t2, t3)
        with self.assertRaises(TypeError):
            torch.mul(t3, t1)
        with self.assertRaises(TypeError):
            torch.mul(t3, t2)
        with self.assertRaises(TypeError):
            torch.mul(t3, t3)
        with self.assertRaises(TypeError):
            baz(t1, t1)
        with self.assertRaises(TypeError):
            baz(t1, t2)
        with self.assertRaises(TypeError):
            baz(t1, t3)
        with self.assertRaises(TypeError):
            baz(t2, t1)
        with self.assertRaises(TypeError):
            baz(t2, t2)
        with self.assertRaises(TypeError):
            baz(t2, t3)
        with self.assertRaises(TypeError):
            baz(t3, t1)
        with self.assertRaises(TypeError):
            baz(t3, t2)
        with self.assertRaises(TypeError):
            baz(t3, t3)

    def test_user_implementation_raises(self):
        """Test that errors raised in user implementations propagate correctly"""
        t1 = DiagonalTensor(5, 2)
        t2 = DiagonalTensor(5, 2)
        with self.assertRaises(ValueError):
            torch.add(t1, t2)
        with self.assertRaises(ValueError):
            quux(t1)

    def test_tensor_subclass_propagation(self):
        """this test exercises the functionality described in
        docs/source/notes/extending.rst#subclassing-torchtensor"""
        t1 = torch.tensor([5])
        t2 = torch.tensor([6])

        s1 = SubTensor2([5])
        s2 = SubTensor2([6])

        ss1 = SubSubTensor2([5])
        ss2 = SubSubTensor2([6])

        sn1 = SubTensor3([5])
        sn2 = SubTensor3([6])

        # Check that leaf subclass is kept regardless of order
        self.assertTrue(isinstance(s1 + t2, SubTensor2))
        self.assertTrue(isinstance(t1 + s2, SubTensor2))
        self.assertTrue(isinstance(s1 + s2, SubTensor2))

        # Check indexing subclass is kept
        self.assertTrue(isinstance(s1[0], SubTensor2))

        # Check case for subclass of subclass.
        self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
        self.assertTrue(isinstance(ss1 + s2, SubSubTensor2))
        self.assertTrue(isinstance(s1 + ss2, SubSubTensor2))
        self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
        self.assertTrue(isinstance(ss1 + t2, SubSubTensor2))
        self.assertTrue(isinstance(t1 + ss2, SubSubTensor2))
        self.assertTrue(isinstance(ss1[0], SubSubTensor2))

        # Make sure unrelated class trees are not merged.
        with self.assertRaises(TypeError):
            s1 + sn2
        with self.assertRaises(TypeError):
            sn1 + s2

    def test_base(self):
        # https://github.com/szagoruyko/pytorchviz/issues/65
        class DummyTensor(torch.Tensor):
            pass

        a = torch.ones(1)
        c = DummyTensor(a)
        self.assertTrue(c._is_view())
        self.assertTrue(c._base is a)

    def test_grad(self):
        # Previously, Tensor-like objects that did not subclass from Tensor
        # did not get wrapped into unary tuples before being passed into
        # handle_torch_function, in contradiction with how Tensor-likes
        # were handled
        #
        # NB: this asserts that the arguments get normalized into a tuple
        # before entering the torch function handler; it could go the
        # other way but beware https://github.com/pytorch/pytorch/issues/76037

        class Dummy:
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                inputs, outputs = args
                self.assertEqual(inputs, (x,))
                self.assertEqual(outputs, (x,))
                return -1

        x = Dummy()
        self.assertEqual(torch.autograd.grad(x, x), -1)

    def test_pow_rpow(self):
        class NothingImplemented(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                return NotImplemented

        class RPowOnly(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if func is torch.Tensor.__rpow__:
                    return -1
                return NotImplemented

        self.assertEqual(NothingImplemented() ** RPowOnly(), -1)


def generate_tensor_like_override_tests(cls):
    from torch.testing._internal.generated.annotated_fn_args import annotated_args

    def test_generator(func, override):
        # If func corresponds to a torch.Tensor method or property.
        if is_tensor_method_or_property(func):
            # Generate an instance by using SubTensor.
            def instance_gen():
                return SubTensor([5])
        else:
            # Otherwise, use TensorLike.
            def instance_gen():
                return TensorLike()

        # FIXME The following code does not support kwonly args without defaults.
        # The fix is easy, as one just needs to save these args when generating
        # the variable annotated_args. The problem is that, if one does so, one
        # finds a number of functions that have problematic signatures in
        # native_functions.yaml. Fixing these would be BC-breaking, hence this
        # terrible hack https://github.com/pytorch/pytorch/issues/67008
        kwargs = {}
        if hasattr(func, "__name__") and "linalg_solve_triangular" in func.__name__:
            kwargs = {"upper": True}

        func_args = []
        is_method = is_tensor_method_or_property(func)

        def _simple_type_parser(func, arg_name, arg_type):
            # Guess valid input to aten function based on type of argument
            if arg_type == "Tensor":
                return instance_gen()
            elif arg_type == "TensorList" or arg_type == "ITensorListRef":
                return [instance_gen(), instance_gen()]
            elif arg_type == "c10::List<::std::optional<Tensor>>":
                return [instance_gen(), instance_gen()]
            elif arg_type == "IntArrayRef" or arg_type == "SymIntArrayRef":
                size = arg.get("size", 2)
                if size == 1:
                    return 1
                else:
                    return [1] * size
            elif arg_type == "Scalar":
                return 3.5
            elif arg_type == "bool":
                return False
            elif arg_type == "Dimname":
                return ""
            elif arg_type == "DimnameList":
                return [""]
            elif arg_type.startswith("int"):
                return 0
            elif arg_type in {"Stream"}:
                return torch.Stream()
            elif arg_type.startswith("float") or arg_type == "double":
                return 1.0
            elif arg_type in {"Generator", "MemoryFormat", "TensorOptions"}:
                return None
            elif arg_type == "ScalarType":
                return torch.float32
            elif arg_type == "c10::string_view":
                return ""
            elif arg_type == "SymInt":
                # TODO: generate actual SymbolicInt
                return 1
            else:
                raise RuntimeError(
                    f"Unsupported argument type {arg_type} for {arg_name} of function {func}"
                )

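        # For example, an "IntArrayRef" annotated with size 2 becomes [1, 1],
        # a "Scalar" becomes 3.5, and Tensor-typed arguments become fresh
        # SubTensor / TensorLike instances from instance_gen().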
        if func in annotated_args:
            for arg in annotated_args[func]:
                # Guess valid input to aten function based on type of argument
                t = arg["simple_type"]
                if t.endswith("?"):
                    t = t[:-1]
                if t == "Tensor" and is_method and arg["name"] == "self":
                    # See "Note: properties and __get__"
                    func = func.__get__(instance_gen())
                    continue
                arg_to_add = _simple_type_parser(func, arg["name"], t)
                if "is_kwarg_only" in arg and arg["is_kwarg_only"] == str(True):
                    kwargs[arg["name"]] = arg_to_add
                else:
                    func_args.append(arg_to_add)
        else:
            args = inspect.getfullargspec(override)
            try:
                func_args = inspect.getfullargspec(func)
                # Remove annotations from argspec
                func_args = type(func_args)(**{**func_args, 'annotations': None})
                if func_args != args:
                    raise RuntimeError(f"Override for {func} doesn't match its argspec.\n"
                                       + f"Original: {inspect.signature(func)}\n"
                                       + f"Override: {inspect.signature(override)}")
            except TypeError:
                pass
            nargs = len(args.args)
            if args.defaults is not None:
                nargs -= len(args.defaults)
            func_args = [instance_gen() for _ in range(nargs)]
            if args.varargs is not None:
                func_args += [instance_gen(), instance_gen()]

        def test(self):
            ret = func(*func_args, **kwargs)
            # ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
            # This is currently the best check but doesn't work for, for example,
            # Tensor.__add__ because it redirects to Tensor.add.
            # See note "_triggered wrapper"
            if not is_method or ret is None:
                self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered)
                return

            self.assertEqual(ret, -1)

        return test
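            # Note: inspect.getfullargspec raises TypeError for functions
            # implemented in C, so the argspec comparison above is skipped
            # for those.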

    for func, override in get_testing_overrides().items():
        test_method = test_generator(func, override)
        if func.__name__ == "__get__":
            # Note: properties and __get__
            # __get__ is part of the descriptor protocol.
            # https://docs.python.org/3/howto/descriptor.html
            # This is used for properties of the form
            # torch.Tensor.<property>, with the method __get__
            # In this case we get the property name in two ways:

            # This case handles properties defined in C.
            module = getattr(
                func.__self__,
                "__qualname__",
                None
            )

            # This one handles properties defined in Python.
            if module is None:
                module = "Tensor." + func.__self__.fget.__name__

            # Unfortunately I couldn't find a way to unify these two cases,
            # and there is no way to do it for general descriptors.
        elif is_tensor_method_or_property(func):
            module = "Tensor"
        else:
            module = func.__module__
        if module:
            name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__)
        else:
            name = f'test_{func.__name__}'
        test_method.__name__ = name
        setattr(cls, name, test_method)

generate_tensor_like_override_tests(TestTorchFunctionOverride)
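
# A note on the generated test names (illustrative): functions in the torch
# namespace become e.g. test_torch_mean, while Tensor methods and properties
# become e.g. test_Tensor_add.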

class Wrapper:
    "Basic data container that knows how to unwrap itself"
    def __init__(self, data):
        self.__dict__["_data"] = data
        self.__dict__["used_attrs"] = set()
        self.__dict__["used_calls"] = set()

    def __getattr__(self, name):
        if name in self.__dict__:
            return self.__dict__[name]
        self.used_attrs.add(name)

        val = getattr(self._data, name)

        # If it's a method
        if not isinstance(val, torch.device) and callable(val):
            c = getattr(type(self._data), name)
            # Don't append self to args if classmethod/staticmethod
            if c is val:
                return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw))
            # Otherwise append self to args
            return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw))

        return wrap(val)

    def __setattr__(self, name, value):
        if name in self.__dict__:
            self.__dict__[name] = value

        self.used_attrs.add(name)
        setattr(self._data, name, unwrap(value))

    def __setitem__(self, key, value):
        self._data[unwrap(key)] = unwrap(value)

    def __getitem__(self, key):
        return wrap(self._data[unwrap(key)])

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Find an instance of this class in the arguments
        args_of_this_cls = []
        for a in args:
            if isinstance(a, cls):
                args_of_this_cls.append(a)
            elif isinstance(a, collections.abc.Sequence):
                args_of_this_cls.extend(el for el in a if isinstance(el, cls))
        assert len(args_of_this_cls) > 0
        for a in args_of_this_cls:
            a.used_calls.add(func)
        args = unwrap(tuple(args))
        kwargs = {k: unwrap(v) for k, v in kwargs.items()}

        return wrap(func(*args, **kwargs))

    def __add__(self, other):
        return self.__torch_function__(torch.add, (Wrapper,), (self, other))

    def __mul__(self, other):
        return self.__torch_function__(torch.mul, (Wrapper,), (self, other))

    def __sub__(self, other):
        return self.__torch_function__(torch.sub, (Wrapper,), (self, other))

    def __truediv__(self, other):
        return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other))

    def __floordiv__(self, other):
        return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other))

    def __ge__(self, other):
        return self.__torch_function__(torch.ge, (Wrapper,), (self, other))

    def __gt__(self, other):
        return self.__torch_function__(torch.gt, (Wrapper,), (self, other))

    def __lt__(self, other):
        return self.__torch_function__(torch.lt, (Wrapper,), (self, other))

    def __le__(self, other):
        return self.__torch_function__(torch.le, (Wrapper,), (self, other))

    def __eq__(self, other):
        return self.__torch_function__(torch.eq, (Wrapper,), (self, other))

    def __ne__(self, other):
        return self.__torch_function__(torch.ne, (Wrapper,), (self, other))

    def __bool__(self):
        return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,))

    def __int__(self):
        return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,))

    def __len__(self):
        return len(self._data)


# unwrap inputs if necessary
def unwrap(v):
    if type(v) in {tuple, list}:
        return type(v)(unwrap(vi) for vi in v)

    return v._data if isinstance(v, Wrapper) else v

# wrap inputs if necessary
def wrap(v):
    if type(v) in {tuple, list}:
        return type(v)(wrap(vi) for vi in v)

    return Wrapper(v) if isinstance(v, torch.Tensor) else v
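
# Illustration of the round trip above: an expression like
# Wrapper(torch.ones(2)) + Wrapper(torch.ones(2)) goes through
# Wrapper.__torch_function__, which unwraps the operands to plain tensors,
# calls torch.add on them, and wraps the resulting tensor in a new Wrapper.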

class TestEinsumOverride(TestCase):
    "Regression test for gh-38479"
    def test_wrapper(self):
        x = Wrapper(torch.randn(5))
        y = Wrapper(torch.randn(4))
        self.assertEqual(torch.einsum('i,j->ij', x, y)._data,
                         torch.ger(x, y)._data)

        # in the old einsum interface, `operands` is a list
        a = Wrapper(torch.randn(2, 3))
        b = Wrapper(torch.randn(5, 3, 7))
        c = Wrapper(torch.randn(2, 7))
        self.assertEqual(torch.einsum('ik,jkl,il->ij', [a, b, c])._data,
                         torch.nn.functional.bilinear(a, c, b)._data)

class TestGradCheckOverride(TestCase):
    "Test that wrappers work with gradcheck."
    def test_gradcheck(self):
        from torch.testing._internal.common_utils import gradcheck, gradgradcheck

        def run_test(fast_mode):
            a = wrap(torch.tensor(5.0, dtype=torch.double))
            b = wrap(torch.tensor(6.0, dtype=torch.double))

            a.requires_grad = True
            b.requires_grad = True

            gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)
            gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)

            total_used_attrs = a.used_attrs.union(b.used_attrs)
            total_used_calls = a.used_calls.union(b.used_calls)

            # These attributes (and the functions below) may change
            # if the gradcheck implementation changes. It's best to
            # aim for attributes that may be commonly present on other
            # Tensor-likes.
            expected_used_attrs = {
                'data',
                'dtype',
                'is_floating_point',
                'is_sparse',
                'layout',
                'new_zeros',
                'numel',
                'requires_grad',
                'requires_grad_',
                'size',
                'stride',
            }
            if fast_mode:
                expected_used_attrs.add('is_complex')
                expected_used_attrs.add('device')
            self.assertEqual(expected_used_attrs, total_used_attrs)

            expected_used_calls = {
                torch.Tensor.new_zeros,
                torch.Tensor.size,
                torch.Tensor.is_floating_point,
                torch.Tensor.numel,
                torch.Tensor.stride,
                torch.Tensor.requires_grad_,
                torch.autograd.grad,
                torch.add,
            }
            if fast_mode:
                expected_used_calls.add(torch.Tensor.is_complex)
            self.assertEqual(expected_used_calls, total_used_calls)
        run_test(fast_mode=True)
        run_test(fast_mode=False)

class TestNamedTuple(TestCase):
    """ Regression test for gh-47090 """
    def test_max(self):
        x = torch.tensor([1, 2])
        xs = x.as_subclass(SubTensor2)
        r = torch.max(x, dim=0)
        rs = torch.max(xs, dim=0)
        self.assertEqual(type(r), type(rs))
        self.assertEqual(r, rs)

class TestGradNewOnesOverride(TestCase):
    """ Regression test for gh-47069 """
    def test_newones(self):
        t = torch.tensor([1, 2]).as_subclass(SubTensor2)
        n = t.new_ones((1, 2))
        self.assertEqual(type(n), SubTensor2)

class TestPickle(TestCase):
    "Regression test for gh-47051"
    def test_pickle(self):
        t = torch.tensor([1]).as_subclass(SubTensor2)
        t.abcd = "e"
        t2 = pickle.loads(pickle.dumps(t))
        self.assertIs(type(t2), SubTensor2)
        self.assertEqual(t2.abcd, "e")

class TestBroadcastAllOverride(TestCase):
    """ test for gh-37141 """
    def test_broadcast_all(self):
        from torch.distributions.utils import broadcast_all
        a = torch.tensor([1.2, 3.4, 5.6])
        a_w = Wrapper(a)
        b = torch.tensor(5.0)
        b_w = Wrapper(b)
        c = torch.tensor([5.0, 5.0, 5.0])

        o_1 = broadcast_all(a_w, b_w)
        self.assertTrue(isinstance(o_1[0], Wrapper))
        self.assertTrue(isinstance(o_1[1], Wrapper))
        self.assertEqual(o_1[0]._data, a)
        self.assertEqual(o_1[1]._data, c)

        o_2 = broadcast_all(a_w, b)
        self.assertTrue(isinstance(o_2[0], Wrapper))
        self.assertTrue(isinstance(o_2[1], Wrapper))
        self.assertEqual(o_2[0]._data, a)
        self.assertEqual(o_2[1]._data, c)

class TestWrapTorchFunction(TestCase):
    def test_wrap_torch_function(self):
        class A:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs):
                return -1

        def dispatcher(a):
            return (a,)

        @torch.overrides.wrap_torch_function(dispatcher)
        def f(a):
            return a

        self.assertEqual(f(A()), -1)

class TestIndexing(TestCase):
    """ Regression tests for gh-46277 """
    def test_getitem(self):
        class A:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                return -1

        t = torch.tensor([5])
        self.assertEqual(t[A()], -1)
        self.assertEqual(t, torch.tensor([5]))

    def test_getitem_subclass(self):
        class A(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                return -1

        t = torch.tensor([5])
        self.assertEqual(t[A()], -1)
        self.assertEqual(t[5, A()], -1)
        self.assertEqual(t, torch.tensor([5]))

    def test_setitem(self):
        triggered = set()

        class A:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                triggered.add(func)
                return -1

        t = torch.tensor([5])
        t[A()] = 1
        t[5, A()] = 1
        self.assertIn(Tensor.__setitem__, triggered)
        self.assertEqual(t, torch.tensor([5]))

    def test_setitem_val(self):
        triggered = set()

        class A:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                triggered.add(func)
                return -1

        t = torch.tensor([5])
        t[0] = A()
        self.assertIn(Tensor.__setitem__, triggered)
        self.assertEqual(t, torch.tensor([5]))

    def test_setitem_subclass(self):
        triggered = set()

        class A(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                triggered.add(func)
                return -1

        t = torch.tensor([5])
        t[A()] = 1
        t[5, A()] = 1
        self.assertIn(Tensor.__setitem__, triggered)
        self.assertEqual(t, torch.tensor([5]))


class TestIterator(TestCase):
    # Regression test for gh-54457
    def test_iterator(self):
        t = torch.tensor([5, 6, 7]).as_subclass(SubTensor2)
        it = iter(t)
        self.assertIs(type(next(it)), SubTensor2)
        self.assertIs(type(next(it)), SubTensor2)
        self.assertIs(type(next(it)), SubTensor2)


class TestRNN(TestCase):
    # Regression test for gh-55868
    def test_rnn(self):
        model = torch.nn.RNN(10, 20, 2)
        input = Wrapper(torch.randn(1, 5, 10))
        model(input)


class TestDisabledTorchFunction(TestCase):
    # Regression test for gh-64687
    def test_parameter_does_not_prevent_dispatch(self):
        class MyTensor:
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                return "called"

        t1 = MyTensor()
        t2 = torch.nn.Parameter(torch.rand(2, 2))
        self.assertEqual(torch.add(t2, t1), "called")

        inp = torch.rand(10, 10)
        self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
        self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")

class TestResolveName(TestCase):
    def test_resolve_name(self):
        for cs in get_overridable_functions().values():
            for c in cs:
                self.assertEqual(
                    eval(torch.overrides.resolve_name(c)),
                    c,
                    msg=f"{c}, {torch.overrides.resolve_name(c)}"
                )

class TestTorchFunctionWarning(TestCase):
    def test_warn_on_invalid_torch_function_standalone_class(self):
        class StandaloneTorchFunctionClass:
            def __torch_function__(self, *args, **kwargs):
                pass
        a = StandaloneTorchFunctionClass()
        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
            # Function that handles torch_function on the python side
            torch.nn.functional.dropout(a)
        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
            # Function that handles torch_function in C++
            torch.abs(a)

    def test_warn_on_invalid_torch_function_tensor_subclass(self):
        class TensorSubclassTorchFunctionClass(torch.Tensor):
            def __torch_function__(self, *args, **kwargs):
                pass
        b = TensorSubclassTorchFunctionClass()
        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
            # Function that handles torch_function on the python side
            torch.nn.functional.dropout(b)
        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
            # Function that handles torch_function in C++
            torch.abs(b)

class TestDisabledUserWarnings(TestCase):
    def test_no_implicit_user_warning_for_deprecated_functions(self):
        self.assertNotWarn(get_ignored_functions)
        self.assertNotWarn(get_testing_overrides)
        self.assertNotWarn(get_overridable_functions)
        self.assertNotWarn(lambda: resolve_name(torch.Tensor.add))
        self.assertNotWarn(lambda: is_tensor_method_or_property(torch.Tensor.add))

@unittest.skipIf(TEST_WITH_CROSSREF, "not run with crossref")
class TestTorchFunctionMode(TestCase):
    def test_basic(self):
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -1
        # NB: factory functions get overridden too!
        x = torch.randn(1)
        with A():
            self.assertEqual(torch.randn(3), -1)
            self.assertEqual(torch.add(x, x), -1)
            self.assertEqual(torch.split(None, [2]), -1)  # python side
            self.assertEqual(bar(x), -1)

    def test_factory_override(self):
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -1

        with A():
            self.assertEqual(torch.tensor([1]), -1)
            self.assertEqual(torch.sparse_coo_tensor(1, 1, 1), -1)
            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1), -1)
            self.assertEqual(torch.sparse_coo_tensor(1, 1, (1, 1), check_invariants=False), -1)
            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1, (1, 1), check_invariants=False), -1)
            self.assertEqual(torch.as_tensor([1]), -1)

    def test_modes_handle_first(self):
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -40

        x = SubTensor()
        with A():
            self.assertEqual(torch.neg(x), -40)
            self.assertEqual(torch.mean(x), -40)
            self.assertEqual(torch.mm(x, x), -40)
            self.assertEqual(bar(x), -40)

    def test_modes_return_notimplemented(self):
        class MyMode(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return NotImplemented

        x = SubTensor()
        with MyMode():
            self.assertEqual(torch.mean(x), 0)
            self.assertEqual(torch.mm(x, x), -1)
            self.assertEqual(bar(x), 1)
            self.assertRaisesRegex(
                TypeError, r'SubTensor',
                lambda: torch.max(x, x))

    def test_with_mode(self):
        class ErrorA(RuntimeError):
            pass

        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                raise ErrorA

        with self.assertRaises(ErrorA):
            with A():
                torch.empty([])

    def test_with_mode_created_separately(self):
        class ErrorA(RuntimeError):
            pass

        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                raise ErrorA

        x = A()
        with self.assertRaises(ErrorA):
            with x:
                torch.empty([])

    def test_with_nested_modes(self):
        out = []

        class A(TorchFunctionMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_function__(self, func, _, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                out.append(self.msg)
                return func(*args, **kwargs)

        with A("layer1"):
            with A("layer2"):
                torch.empty([])

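        # The innermost mode handles the call first and then delegates
        # outward, so "layer2" is logged before "layer1".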
        self.assertEqual(out, ["layer2", "layer1"])

    def test_nested_same_mode(self):
        out = []

        class A(TorchFunctionMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_function__(self, func, _, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                out.append(self.msg)
                return func(*args, **kwargs)

        with A("layer1") as a:
            with a:
                torch.empty([])

        self.assertEqual(out, ["layer1", "layer1"])

    def test_error_using_class_method_on_mode(self):
        class A(TorchFunctionMode):
            @classmethod
            def __torch_function__(cls, func, _, args=(), kwargs=None):
                return func(args, kwargs)

        x = torch.tensor(5.)
        with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"):
            with A():
                x + x

    def test_restacking_with_ancestor(self):
        class A(TorchFunctionMode):
            pass

        with A():
            with A() as x:
                pass

        with x:
            pass

    def test_get_cur_mode(self):
        class A(TorchFunctionMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                pass

        with A() as mode1:
            self.assertEqual(_get_current_function_mode(), mode1)

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_function_mode(), mode2)


    def test_get_mode_stack(self):
        class A(TorchFunctionMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                pass

        self.assertEqual(_get_current_function_mode_stack(), [])

        with A() as mode1:
            self.assertEqual(_get_current_function_mode_stack(), [mode1])

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2])

    def test_all_same_mode(self):
        class A(TorchFunctionMode):
            pass

        x = A()
        y = A()
        self.assertTrue(all_same_mode([x, x, x]))
        self.assertFalse(all_same_mode([x, None]))
        self.assertFalse(all_same_mode([x, y]))

    def test_nested_modes_with_python_has_torch_function(self):
        called = []

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                called.append("A")
                kwargs = {} if kwargs is None else kwargs
                return func(*args, **kwargs)

        class B(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                called.append("B")
                kwargs = {} if kwargs is None else kwargs
                return func(*args, **kwargs)

        x = torch.randn(3, 4)
        with A():
            with B():
                y = bar(x)

        self.assertEqual(y, x)
        self.assertEqual(called, ["B", "A"])


    def test_reentrant_mode_idiom(self):
        log = []

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                log.append(func)
                if func is torch.sub:
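                    # Re-entering the mode with `with self:` re-enables it
                    # inside this handler, so the torch.add call below is
                    # intercepted (and logged) by this mode as well.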
                    with self:
                        input, other = args
                        assert not kwargs
                        return torch.add(input, other, alpha=-1)
                return func(*args, **kwargs)

        x = torch.randn(1)
        y = torch.randn(1)
        with A():
            torch.sub(x, y)
        # add hits the torch function again!
        self.assertEqual(log, [torch.sub, torch.add])

    def test_nn_parse_to(self):
        # This failed because the parser thinks the function is called to()
        # but it's actually called _parse_to()

        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        with A():
            torch._C._nn._parse_to('cpu')

        self.assertTrue(called)

    def test_getitem_call(self):
        # Check that plain tensor indexing (Tensor.__getitem__) is routed
        # through the mode

        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        a = torch.zeros(5)
        b = torch.tensor(0)
        with A():
            a[b]

        self.assertTrue(called)


    def test_distributions_bernoulli(self):
        # This failed because improper use of has_torch_function when
        # is_tensor_like should have been used instead, inside the
        # broadcasting logic called by distributions (Bernoulli doesn't
        # matter per se)

        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        with A():
            torch.distributions.Bernoulli(0.3)

        self.assertTrue(called)

    def test_mode_notimplemented_loop(self):
        # Default tensor subclass implementation disables torch function;
        # when we redispatch to mode we must not treat the objects as
        # eligible

        called = 0

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called += 1
                # The first time we call, the mode sees an active type that
                # it doesn't know how to deal with.  The second time, we're
                # instructed to treat it "as if it were a tensor", and so
                # we keep going.  I'm not entirely clear if the subclasses
                # disappearing from types is the correct way to do it.
                if any(t is not torch.Tensor for t in types):
                    return NotImplemented
                else:
                    return func(*args, **kwargs)

        class B(torch.Tensor):
            pass

        b = B()

        with A():
            r = torch.neg(b)

        self.assertIs(type(r), B)
        self.assertEqual(called, 2)

        called = 0

        with A():
            r = bar(b)

        self.assertIs(type(r), B)
        self.assertEqual(called, 2)

    def test_disable_subclass_not_mode(self):
        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        class B(torch.Tensor):
            pass

        x = B(torch.randn(5))
        with A():
            with torch._C.DisableTorchFunctionSubclass():
                self.assertNotIsInstance(torch.sum(x), B)

        self.assertTrue(called)

    def test_disable_subclass_mode(self):
        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        class B(torch.Tensor):
            pass

        x = B(torch.randn(5))
        with A():
            with torch._C.DisableTorchFunction():
                self.assertNotIsInstance(torch.sum(x), B)

        self.assertFalse(called)

    def test_disable_enable_subclass(self):
        called = False

        class A(torch.Tensor):
            pass

        x = A(torch.randn(5))
        with torch._C.DisableTorchFunctionSubclass():
            g = torch._C._EnableTorchFunction()
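            # While `g` is alive, __torch_function__ dispatch is re-enabled
            # despite the surrounding DisableTorchFunctionSubclass context;
            # deleting it below restores the disabled state.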
            try:
                self.assertIsInstance(torch.sum(x), A)
            finally:
                del g

    def test_torch_function_all_disabled_api(self):
        from torch._C import _is_torch_function_all_disabled

        state = _is_torch_function_all_disabled()
        self.assertFalse(state)

        with torch._C.DisableTorchFunction():
            state = _is_torch_function_all_disabled()
            self.assertTrue(state)

        state = _is_torch_function_all_disabled()
        self.assertFalse(state)

        with torch._C.DisableTorchFunctionSubclass():
            state = _is_torch_function_all_disabled()
            self.assertFalse(state)

    def test_subclass_hash(self):
        class DiagTensor(torch.Tensor):
            def __init__(self, diag):
                self._diag = diag

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}

                def get_full_matrices(t):
                    if isinstance(t, DiagTensor):
                        return torch.diag_embed(t._diag)
                    else:
                        return t

                return func(*tree_map(get_full_matrices, args), **tree_map(get_full_matrices, kwargs))

        d = torch.rand(2)
        a = DiagTensor(d)

        self.assertEqual((a + 1), torch.diag_embed(d) + 1)

        # If the hash function was returning the same value, this would
        # fail inside `Tensor.__eq__`.
        # If __hash__ was going through torch_function, the implementation above would
        # be wrong as it would compute the hash on a temporary Tensor thus not ensuring
        # the uniqueness of the hash that we rely on for Tensors.
        s = set()
        s.add(a)
        s.add(DiagTensor(d))

    def test_custom_device_type(self):
        class CustomDeviceContext(TorchFunctionMode):

            def __torch_function__(self, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}
                if func == torch.device:
                    if args and isinstance(args[0], int):
                        args = ("xla", args[0])
                    elif isinstance(kwargs.get('device'), int):
                        kwargs['device'] = f"xla:{kwargs.get('device')}"
                return func(*args, **kwargs)

        with CustomDeviceContext():
            d_args = torch.device(0)
            self.assertEqual(d_args.type, "xla")
            self.assertEqual(d_args.index, 0)
            d_kwargs = torch.device(device=0)
            self.assertEqual(d_kwargs.type, "xla")
            self.assertEqual(d_kwargs.index, 0)

    def test_device_context_semantics(self):
        from torch._C import _len_torch_function_stack
        from torch.utils._device import DeviceContext
        try:
            torch.set_default_device("cuda")

            def get_stack():
                return [torch._C._get_function_stack_at(i) for i in range(_len_torch_function_stack())]

            base_mode = BaseTorchFunctionMode()
            with base_mode:
                torch.set_default_device("cpu")
                x = torch.ones(2, 2)
                stack = get_stack()
                self.assertIsInstance(stack[0], DeviceContext)
                self.assertEqual(stack[0].device, torch.device("cpu"))

            stack = get_stack()
            self.assertIsInstance(stack[0], DeviceContext)
            self.assertEqual(stack[0].device, torch.device("cpu"))
        finally:
            torch.set_default_device(None)


if __name__ == '__main__':
    run_tests()