xref: /aosp_15_r20/external/pytorch/test/test_overrides.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: __torch_function__"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Workerimport numpy as np
5*da0073e9SAndroid Build Coastguard Workerimport inspect
6*da0073e9SAndroid Build Coastguard Workerimport functools
7*da0073e9SAndroid Build Coastguard Workerimport pprint
8*da0073e9SAndroid Build Coastguard Workerimport pickle
9*da0073e9SAndroid Build Coastguard Workerimport collections
10*da0073e9SAndroid Build Coastguard Workerimport unittest
11*da0073e9SAndroid Build Coastguard Workerimport contextlib
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO
14*da0073e9SAndroid Build Coastguard Workerfrom torch.overrides import (
15*da0073e9SAndroid Build Coastguard Worker    handle_torch_function,
16*da0073e9SAndroid Build Coastguard Worker    has_torch_function,
17*da0073e9SAndroid Build Coastguard Worker    get_ignored_functions,
18*da0073e9SAndroid Build Coastguard Worker    get_overridable_functions,
19*da0073e9SAndroid Build Coastguard Worker    get_testing_overrides,
20*da0073e9SAndroid Build Coastguard Worker    resolve_name,
21*da0073e9SAndroid Build Coastguard Worker    is_tensor_method_or_property,
22*da0073e9SAndroid Build Coastguard Worker    TorchFunctionMode,
23*da0073e9SAndroid Build Coastguard Worker    _get_current_function_mode,
24*da0073e9SAndroid Build Coastguard Worker    _get_current_function_mode_stack,
25*da0073e9SAndroid Build Coastguard Worker    BaseTorchFunctionMode
26*da0073e9SAndroid Build Coastguard Worker)
27*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._mode_utils import all_same_mode
28*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map
29*da0073e9SAndroid Build Coastguard Worker
# Module-level shorthand for ``torch.Tensor``.
from torch import Tensor
31*da0073e9SAndroid Build Coastguard Worker
# The functions below simulate the pure-Python torch functions in the
# torch.functional namespace. We use examples local to this file rather than
# any of the real Python implementations, since those might be reimplemented
# in C++ for speed in the future. These fake torch functions let us verify
# that the dispatch rules work the same way for a torch function implemented
# in C++ or in Python.

def foo(a, b, c=None):
    """A function with multiple arguments and an optional argument.

    If any of ``a``, ``b`` or ``c`` defines ``__torch_function__``, the call
    is dispatched through ``handle_torch_function``. Otherwise returns
    ``a + b``, or ``a + b + c`` when ``c`` is supplied.
    """
    if has_torch_function((a, b, c)):
        return handle_torch_function(foo, (a, b, c), a, b, c=c)
    # Test for "not supplied" explicitly: ``if c:`` raises RuntimeError when
    # ``c`` is a tensor with more than one element.
    if c is not None:
        return a + b + c
    return a + b
46*da0073e9SAndroid Build Coastguard Worker
def bar(a):
    """A function with one argument.

    Returns ``a`` unchanged, unless ``a`` defines ``__torch_function__``, in
    which case the call is dispatched through ``handle_torch_function``.
    """
    dispatch_args = (a,)
    if not has_torch_function(dispatch_args):
        return a
    return handle_torch_function(bar, dispatch_args, a)
52*da0073e9SAndroid Build Coastguard Worker
def baz(a, b):
    """A function with multiple arguments.

    Returns ``a + b`` unless either argument defines ``__torch_function__``,
    in which case the call is dispatched through ``handle_torch_function``.
    """
    dispatch_args = (a, b)
    if not has_torch_function(dispatch_args):
        return a + b
    return handle_torch_function(baz, dispatch_args, a, b)
58*da0073e9SAndroid Build Coastguard Worker
def quux(a):
    """Used to test that errors raised in user implementations get propagated.

    Returns ``a`` unchanged unless ``a`` defines ``__torch_function__``; in
    that case the call is dispatched through ``handle_torch_function`` and
    whatever the override returns (or raises) is passed straight through.
    """
    dispatch_args = (a,)
    if not has_torch_function(dispatch_args):
        return a
    return handle_torch_function(quux, dispatch_args, a)
64*da0073e9SAndroid Build Coastguard Worker
# HANDLED_FUNCTIONS_DIAGONAL is the dispatch table consulted by
# DiagonalTensor.__torch_function__ to find the override for a given torch
# API function. Its keys are the torch API function objects themselves (not
# their names) and its values are the implementations to call in their place.
# Entries are added by decorating a Python function with implements_diagonal;
# see the overrides immediately below the definition of DiagonalTensor for
# usage examples.
HANDLED_FUNCTIONS_DIAGONAL = {}


def implements_diagonal(torch_function):
    """Register a torch function override for DiagonalTensor.

    Takes a function from the torch API and returns a decorator. Decorating
    a Python function with it stores that function in
    ``HANDLED_FUNCTIONS_DIAGONAL`` under ``torch_function`` and returns it
    unchanged. See ``DiagonalTensor.__torch_function__`` for the runtime
    dispatch and the decorated functions right after ``DiagonalTensor`` for
    usage examples.
    """
    def register(func):
        HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func
        return func

    # The returned decorator carries torch_function's metadata (__name__,
    # __doc__, __wrapped__, ...), exactly as with @functools.wraps.
    return functools.wraps(torch_function)(register)
90*da0073e9SAndroid Build Coastguard Worker
class DiagonalTensor:
    """A class with __torch_function__ and a specific diagonal representation

    This class has limited utility and is mostly useful for verifying that the
    dispatch mechanism works as expected. It is based on the `DiagonalArray
    example`_ in the NumPy documentation.

    Note that this class does *not* inherit from ``torch.Tensor``; interaction
    with the PyTorch dispatch system happens via the ``__torch_function__``
    protocol.

    ``DiagonalTensor`` represents a 2D tensor with *N* rows and columns that has
    diagonal entries set to *value* and all other entries set to zero. The
    main functionality of ``DiagonalTensor`` is to provide a more compact
    string representation of a diagonal tensor than in the base tensor class:

    >>> d = DiagonalTensor(5, 2)
    >>> d
    DiagonalTensor(N=5, value=2)
    >>> d.tensor()
    tensor([[2., 0., 0., 0., 0.],
            [0., 2., 0., 0., 0.],
            [0., 0., 2., 0., 0.],
            [0., 0., 0., 2., 0.],
            [0., 0., 0., 0., 2.]])

    Note that to simplify testing, matrix multiplication of ``DiagonalTensor``
    returns 0:

    >>> torch.mm(d, d)
    0

    .. _DiagonalArray example:
        https://numpy.org/devdocs/user/basics.dispatch.html
    """
    # Defined as a class attribute so that SubDiagonalTensor (below), which
    # subclasses DiagonalTensor, can re-use DiagonalTensor's
    # __torch_function__ implementation while supplying its own table.
    handled_functions = HANDLED_FUNCTIONS_DIAGONAL

    def __init__(self, N, value):
        # N: number of rows (and columns); value: the entry on the diagonal.
        self._N = N
        self._i = value

    def __repr__(self):
        return f"DiagonalTensor(N={self._N}, value={self._i})"

    def __array__(self):
        """Materialize as a dense NumPy array equal to ``value * eye(N)``."""
        return self._i * np.eye(self._N)

    def tensor(self):
        """Materialize as a dense torch tensor equal to ``value * eye(N)``."""
        return self._i * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``func`` to its override in ``cls.handled_functions``.

        Returns ``NotImplemented`` when ``func`` has no registered override.
        """
        if kwargs is None:
            kwargs = {}
        if func not in cls.handled_functions:
            return NotImplemented
        return cls.handled_functions[func](*args, **kwargs)

    def __eq__(self, other):
        # Equal only to an instance of exactly the same class with the same
        # size and diagonal value. Because __eq__ is defined without
        # __hash__, instances are unhashable.
        return type(other) is type(self) and self._N == other._N and self._i == other._i
154*da0073e9SAndroid Build Coastguard Worker
@implements_diagonal(torch.mean)
def mean(mat):
    """``torch.mean`` override for DiagonalTensor.

    An N x N matrix with ``value`` on the diagonal and zeros elsewhere has
    mean ``value / N``; this computes it without materializing the matrix.
    """
    diagonal_value = float(mat._i)
    return diagonal_value / mat._N
158*da0073e9SAndroid Build Coastguard Worker
@implements_diagonal(torch.mm)
def diagonal_mm(mat1, mat2):
    """``torch.mm`` override for DiagonalTensor; always returns 0 to simplify testing."""
    del mat1, mat2  # operands are deliberately ignored
    return 0
162*da0073e9SAndroid Build Coastguard Worker
@implements_diagonal(torch.div)
def diagonal_div(input, other, out=None):
    """``torch.div`` override for DiagonalTensor; always returns -1."""
    del input, other, out  # operands are deliberately ignored
    return -1
166*da0073e9SAndroid Build Coastguard Worker
@implements_diagonal(torch.add)
def add(mat1, mat2):
    """``torch.add`` override for DiagonalTensor; always raises ``ValueError``.

    NOTE(review): presumably used to check that exceptions raised inside an
    override propagate to the caller — confirm against the tests using it.
    """
    del mat1, mat2  # operands are deliberately ignored
    raise ValueError()
170*da0073e9SAndroid Build Coastguard Worker
@implements_diagonal(foo)
def diagonal_foo(a, b, c=None):
    """Override of the Python-level ``foo`` for DiagonalTensor; always returns -1."""
    del a, b, c  # arguments are deliberately ignored
    return -1
174*da0073e9SAndroid Build Coastguard Worker
@implements_diagonal(bar)
def diagonal_bar(a):
    """Override of the Python-level ``bar`` for DiagonalTensor; always returns -1."""
    del a  # argument is deliberately ignored
    return -1
178*da0073e9SAndroid Build Coastguard Worker
@implements_diagonal(quux)
def diagonal_quux(a):
    """Override of ``quux`` for DiagonalTensor that always raises ``ValueError``.

    ``quux`` exists to test that errors raised in user implementations are
    propagated; this override supplies such an error.
    """
    del a  # argument is deliberately ignored
    raise ValueError()
182*da0073e9SAndroid Build Coastguard Worker
# The dispatch table for SubTensor's __torch_function__ implementation,
# keyed by the torch API function object being overridden.
HANDLED_FUNCTIONS_SUB = {}


def implements_sub(torch_function):
    """Register a torch function override for SubTensor.

    Returns a decorator that stores the decorated function in
    ``HANDLED_FUNCTIONS_SUB`` under ``torch_function`` and returns it
    unchanged.
    """
    def register(func):
        HANDLED_FUNCTIONS_SUB[torch_function] = func
        return func

    # Same metadata copying as decorating ``register`` with @functools.wraps.
    return functools.wraps(torch_function)(register)
193*da0073e9SAndroid Build Coastguard Worker
class SubTensor(torch.Tensor):
    """A subclass of torch.Tensor used for testing __torch_function__ dispatch

    This class has the property that matrix multiplication returns -1
    (the value produced by its registered ``torch.mm`` override):

    >>> s = SubTensor([[1, 1], [1, 1]])
    >>> torch.mm(s, s)
    -1
    >>> t = torch.tensor([[1, 1], [1, 1]])
    >>> torch.mm(s, t)
    -1
    >>> torch.mm(t, s)
    -1
    >>> torch.mm(t, t)
    tensor([[2, 2],
            [2, 2]])

    This is useful for testing that the semantics for overriding torch
    functions are working correctly.
    """
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``func`` through the module-level ``HANDLED_FUNCTIONS_SUB`` table.

        Returns ``NotImplemented`` for functions without a registered override.
        """
        if kwargs is None:
            kwargs = {}

        if func not in HANDLED_FUNCTIONS_SUB:
            return NotImplemented
        return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs)
222*da0073e9SAndroid Build Coastguard Worker
class SubTensor2(torch.Tensor):
    """A plain ``torch.Tensor`` subclass that defines no ``__torch_function__`` of its own."""
225*da0073e9SAndroid Build Coastguard Worker
class SubSubTensor2(SubTensor2):
    """A second-level subclass of ``SubTensor2`` that adds nothing of its own."""
228*da0073e9SAndroid Build Coastguard Worker
class SubTensor3(torch.Tensor):
    """Another plain ``torch.Tensor`` subclass, unrelated to ``SubTensor2``."""
231*da0073e9SAndroid Build Coastguard Worker
@implements_sub(torch.mean)
def sub_mean(mat):
    """``torch.mean`` override for SubTensor; always returns 0."""
    del mat  # argument is deliberately ignored
    return 0
235*da0073e9SAndroid Build Coastguard Worker
@implements_sub(torch.mm)
def sub_mm(mat1, mat2):
    """``torch.mm`` override for SubTensor; always returns -1."""
    del mat1, mat2  # operands are deliberately ignored
    return -1
239*da0073e9SAndroid Build Coastguard Worker
@implements_sub(bar)
def sub_bar(mat):
    """Override of the Python-level ``bar`` for SubTensor; always returns 1."""
    del mat  # argument is deliberately ignored
    return 1
243*da0073e9SAndroid Build Coastguard Worker
@implements_sub(torch.div)
def sub_div(input, other, out=None):
    """``torch.div`` override for SubTensor that declines to handle the call.

    Returning ``NotImplemented`` here makes ``SubTensor.__torch_function__``
    return ``NotImplemented`` as well.
    """
    del input, other, out  # operands are deliberately ignored
    return NotImplemented
247*da0073e9SAndroid Build Coastguard Worker
# The dispatch table for SubDiagonalTensor's __torch_function__
# implementation, keyed by the torch API function object being overridden.
HANDLED_FUNCTIONS_SUB_DIAGONAL = {}


def implements_sub_diagonal(torch_function):
    """Register a torch function override for SubDiagonalTensor.

    Returns a decorator that stores the decorated function in
    ``HANDLED_FUNCTIONS_SUB_DIAGONAL`` under ``torch_function`` and returns
    it unchanged.
    """
    def register(func):
        HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func
        return func

    # Same metadata copying as decorating ``register`` with @functools.wraps.
    return functools.wraps(torch_function)(register)
258*da0073e9SAndroid Build Coastguard Worker
class SubDiagonalTensor(DiagonalTensor):
    """A subclass of ``DiagonalTensor`` to test custom dispatch

    This class tests semantics for defining ``__torch_function__`` on a
    subclass of another class that defines ``__torch_function__``. It
    re-uses ``DiagonalTensor.__torch_function__`` but points
    ``handled_functions`` at its own table, so it has its own set of
    overrides:

    * ``mean`` is scaled by a factor of 10;
    * ``mm`` returns 1 instead of 0;
    * ``bar`` returns 0;
    * ``div`` and ``foo`` return ``NotImplemented``, declining the call.

    Because the table is replaced rather than extended, overrides registered
    only for ``DiagonalTensor`` (e.g. ``add``, ``quux``) are not handled by
    this class. It also has a slightly different repr.
    """
    handled_functions = HANDLED_FUNCTIONS_SUB_DIAGONAL

    def __repr__(self):
        return f"SubDiagonalTensor(N={self._N}, value={self._i})"
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker
@implements_sub_diagonal(torch.mean)
def sub_diagonal_mean(mat):
    """``torch.mean`` override for SubDiagonalTensor: ten times the true mean ``value / N``."""
    scaled_value = 10 * float(mat._i)
    return scaled_value / mat._N
278*da0073e9SAndroid Build Coastguard Worker
@implements_sub_diagonal(bar)
def sub_diagonal_bar(mat):
    """Override of the Python-level ``bar`` for SubDiagonalTensor; always returns 0."""
    del mat  # argument is deliberately ignored
    return 0
282*da0073e9SAndroid Build Coastguard Worker
@implements_sub_diagonal(torch.mm)
def sub_diagonal_mm(mat1, mat2):
    """``torch.mm`` override for SubDiagonalTensor; always returns 1."""
    del mat1, mat2  # operands are deliberately ignored
    return 1
286*da0073e9SAndroid Build Coastguard Worker
@implements_sub_diagonal(torch.div)
def sub_diagonal_div(input, other, out=None):
    """``torch.div`` override for SubDiagonalTensor that declines to handle the call.

    The ``NotImplemented`` result is what the inherited ``__torch_function__``
    returns to the dispatcher.
    """
    del input, other, out  # operands are deliberately ignored
    return NotImplemented
290*da0073e9SAndroid Build Coastguard Worker
@implements_sub_diagonal(foo)
def sub_diagonal_foo(a, b, c=None):
    """Override of ``foo`` for SubDiagonalTensor that declines the call by returning ``NotImplemented``."""
    del a, b, c  # arguments are deliberately ignored
    return NotImplemented
294*da0073e9SAndroid Build Coastguard Worker
# The dispatch table for TensorLike's __torch_function__ implementation,
# keyed by the torch API function object being overridden.
HANDLED_FUNCTIONS_TENSOR_LIKE = {}


# Note: _triggered wrapper
# Maps each function from get_testing_overrides to its testing implementation
# wrapped by triggered_wrapper. The wrapper carries a _triggered flag that is
# set to True once the wrapped implementation has been called, so tests can
# check whether a given override was actually reached.
WRAPPED_TRIGGERED_IMPLS = {}


def triggered_wrapper(f):
    """Wrap ``f`` so that calling it records ``wrapped._triggered = True``.

    The returned wrapper starts with ``_triggered = False``. The flag is set
    before ``f`` runs, so it is set even if ``f`` raises.
    """
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        wrapped._triggered = True
        return f(*args, **kwargs)

    wrapped._triggered = False
    return wrapped


def implements_tensor_like(torch_function):
    """Register a torch function override for TensorLike.

    Returns a decorator that stores the decorated function in
    ``HANDLED_FUNCTIONS_TENSOR_LIKE`` under ``torch_function`` and returns it
    unchanged.
    """
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func
        return func
    return decorator
322*da0073e9SAndroid Build Coastguard Worker
def generate_tensor_like_torch_implementations():
    """Register a triggered-wrapped testing override for every overridable function.

    First verifies that every function reported by
    ``get_overridable_functions`` has an entry in ``get_testing_overrides``
    (apart from names in ``testing_ignore``), raising ``AssertionError`` that
    lists any that do not. Then each testing override is wrapped with
    ``triggered_wrapper``, recorded in ``WRAPPED_TRIGGERED_IMPLS``, and
    registered in SubTensor's table when the function is a Tensor method or
    property, or in TensorLike's table otherwise.
    """
    untested_funcs = []
    testing_overrides = get_testing_overrides()
    # test/test_cpp_api_parity.py monkeypatches torch.nn to have a new
    # function sample_functional.  Depending on what order you run pytest
    # collection, this may trigger the error here.  This is a hack to fix
    # the problem.  A more proper fix is to make the "not tested" check
    # a test on its own, and to make sure the monkeypatch is only installed
    # for the span of the relevant test (and deleted afterwards)
    testing_ignore = {"sample_functional", "autocast"}
    for namespace, funcs in get_overridable_functions().items():
        for func in funcs:
            if func not in testing_overrides and func.__name__ not in testing_ignore:
                untested_funcs.append(f"{namespace}.{func.__name__}")
    if untested_funcs:
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with -O.
        msg = (
            "The following functions are not tested for __torch_function__ "
            "support, please ensure there is an entry in the dict returned by "
            "torch.overrides.get_testing_overrides for this function or if a "
            "__torch_function__ override does not make sense, add an entry to "
            "the tuple returned by torch.overrides.get_ignored_functions.\n\n{}"
        )
        raise AssertionError(msg.format(pprint.pformat(untested_funcs)))
    for func, override in testing_overrides.items():
        # See note: "_triggered wrapper"
        wrapped = triggered_wrapper(override)
        WRAPPED_TRIGGERED_IMPLS[func] = wrapped
        # Tensor methods/properties are routed to SubTensor (a real Tensor
        # subclass); everything else goes to the non-Tensor TensorLike.
        if is_tensor_method_or_property(func):
            implements_sub(func)(wrapped)
        else:
            implements_tensor_like(func)(wrapped)

generate_tensor_like_torch_implementations()
358*da0073e9SAndroid Build Coastguard Worker
class TensorLike:
    """A class that overrides the torch API without subclassing ``torch.Tensor``.

    Used to test explicitly that the torch API can be overridden by a class
    that defines ``__torch_function__``.
    """
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Look up the testing override registered for ``func``; decline any
        # function that has none.
        try:
            override = HANDLED_FUNCTIONS_TENSOR_LIKE[func]
        except KeyError:
            return NotImplemented
        return override(*args, **(kwargs if kwargs is not None else {}))
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Workerclass TestTorchFunctionOverride(TestCase):
376*da0073e9SAndroid Build Coastguard Worker    @classmethod
377*da0073e9SAndroid Build Coastguard Worker    def setUpClass(cls):
378*da0073e9SAndroid Build Coastguard Worker        cls._stack = contextlib.ExitStack()
379*da0073e9SAndroid Build Coastguard Worker        if TEST_WITH_TORCHDYNAMO:
380*da0073e9SAndroid Build Coastguard Worker            # Add classes to the wrapped tensor subclasses
381*da0073e9SAndroid Build Coastguard Worker            @contextlib.contextmanager
382*da0073e9SAndroid Build Coastguard Worker            def setup_subclasses():
383*da0073e9SAndroid Build Coastguard Worker                old = set(torch._dynamo.config.traceable_tensor_subclasses)
384*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor)
385*da0073e9SAndroid Build Coastguard Worker                try:
386*da0073e9SAndroid Build Coastguard Worker                    yield
387*da0073e9SAndroid Build Coastguard Worker                finally:
388*da0073e9SAndroid Build Coastguard Worker                    torch._dynamo.config.traceable_tensor_subclasses.clear()
389*da0073e9SAndroid Build Coastguard Worker                    torch._dynamo.config.traceable_tensor_subclasses.update(old)
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker            cls._stack.enter_context(setup_subclasses())
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker    @classmethod
394*da0073e9SAndroid Build Coastguard Worker    def tearDownClass(cls):
395*da0073e9SAndroid Build Coastguard Worker        cls._stack.close()
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker    def test_mean_semantics(self):
398*da0073e9SAndroid Build Coastguard Worker        """Test that a function with one argument can be overridden"""
399*da0073e9SAndroid Build Coastguard Worker        t1 = DiagonalTensor(5, 2)
400*da0073e9SAndroid Build Coastguard Worker        t2 = SubTensor([[1, 2], [1, 2]])
401*da0073e9SAndroid Build Coastguard Worker        t3 = SubDiagonalTensor(5, 2)
402*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mean(t1), 0.4)
403*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bar(t1), -1)
404*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mean(t2), 0)
405*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bar(t2), 1)
406*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mean(t3), 4.0)
407*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bar(t3), 0)
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker    def test_has_torch_function_non_sequence(self):
410*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "expected a sequence"):
411*da0073e9SAndroid Build Coastguard Worker            has_torch_function(object())
412*da0073e9SAndroid Build Coastguard Worker
413*da0073e9SAndroid Build Coastguard Worker    def test_mm_semantics(self):
414*da0073e9SAndroid Build Coastguard Worker        """Test that a function with multiple arguments can be overridden"""
415*da0073e9SAndroid Build Coastguard Worker        t1 = DiagonalTensor(5, 2)
416*da0073e9SAndroid Build Coastguard Worker        t2 = torch.eye(5) * 2
417*da0073e9SAndroid Build Coastguard Worker        t3 = SubTensor([[1, 2], [1, 2]])
418*da0073e9SAndroid Build Coastguard Worker        t4 = SubDiagonalTensor(5, 2)
419*da0073e9SAndroid Build Coastguard Worker        # only DiagonalTensor so should always get DiagonalTensor result
420*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t1, t1), 0)
421*da0073e9SAndroid Build Coastguard Worker        # tensor and DiagonalTensor, always return DiagonalTensor result
422*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t1, t2), 0)
423*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t2, t1), 0)
424*da0073e9SAndroid Build Coastguard Worker        # only SubTensor so should always get SubTensor result
425*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t3, t3), -1)
426*da0073e9SAndroid Build Coastguard Worker        # tensor and SubTensor so should always get SubTensor result
427*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t3, t2), -1)
428*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t2, t3), -1)
429*da0073e9SAndroid Build Coastguard Worker        # DiagonalTensor and SubTensor are unrelated classes so the result
430*da0073e9SAndroid Build Coastguard Worker        # depends on which argument appears first
431*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t3, t1), -1)
432*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t1, t3), 0)
433*da0073e9SAndroid Build Coastguard Worker        # SubDiagonalTensor should take precedence over DiagonalTensor
434*da0073e9SAndroid Build Coastguard Worker        # but should behave otherwise the same as DiagonalTensor
435*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t4, t4), 1)
436*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t4, t1), 1)
437*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t1, t4), 1)
438*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t4, t2), 1)
439*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t2, t4), 1)
440*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t3, t4), -1)
441*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.mm(t4, t3), 1)
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker    def test_precedence_semantics(self):
444*da0073e9SAndroid Build Coastguard Worker        """Test semantics for __torch_function__ for functions that take
445*da0073e9SAndroid Build Coastguard Worker        multiple arguments
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker        For functions that take multiple arguments, the appropriate
448*da0073e9SAndroid Build Coastguard Worker        __torch_function__ implementation to call is determined by
449*da0073e9SAndroid Build Coastguard Worker        examining the types of the arguments. The precedence order is
450*da0073e9SAndroid Build Coastguard Worker        left-to-right in the argument list, except subclasses are always
451*da0073e9SAndroid Build Coastguard Worker        checked before superclasses. The first result of calling the
452*da0073e9SAndroid Build Coastguard Worker        implementations in precedence order that is not NotImplemented
453*da0073e9SAndroid Build Coastguard Worker        is returned to the user. If all implementations return
454*da0073e9SAndroid Build Coastguard Worker        NotImplemented, a TypeError is raised.
455*da0073e9SAndroid Build Coastguard Worker
456*da0073e9SAndroid Build Coastguard Worker        All cases are tested with functions implemented in C++ and
457*da0073e9SAndroid Build Coastguard Worker        either foo or baz, which are python functions defined above that
458*da0073e9SAndroid Build Coastguard Worker        are instrumented to obey the same dispatch rules as the
459*da0073e9SAndroid Build Coastguard Worker        functions in torch.functional.
460*da0073e9SAndroid Build Coastguard Worker        """
461*da0073e9SAndroid Build Coastguard Worker        # DiagonalTensor has a valid override and SubDiagonal has an
462*da0073e9SAndroid Build Coastguard Worker        # override that returns NotImplemented so we should call the
463*da0073e9SAndroid Build Coastguard Worker        # DiagonalTensor implementation, returning -1
464*da0073e9SAndroid Build Coastguard Worker        t1 = DiagonalTensor(5, 2)
465*da0073e9SAndroid Build Coastguard Worker        t2 = SubDiagonalTensor(5, 2)
466*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.div(t1, t2), -1)
467*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.div(t2, t1), -1)
468*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(t1, t2), -1)
469*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(t2, t1), -1)
470*da0073e9SAndroid Build Coastguard Worker
471*da0073e9SAndroid Build Coastguard Worker        # SubTensor has an implementation that returns NotImplemented as
472*da0073e9SAndroid Build Coastguard Worker        # well so it should behave exactly like SubDiagonalTensor in the
473*da0073e9SAndroid Build Coastguard Worker        # test above
474*da0073e9SAndroid Build Coastguard Worker        t3 = SubTensor([[1, 2], [1, 2]])
475*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.div(t1, t3), -1)
476*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.div(t3, t1), -1)
477*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(t1, t3), -1)
478*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(t3, t1), -1)
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker        # div between SubTensor and SubDiagonalTensor should raise
481*da0073e9SAndroid Build Coastguard Worker        # TypeError since both have an implementation that
482*da0073e9SAndroid Build Coastguard Worker        # explicitly returns NotImplemented
483*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
484*da0073e9SAndroid Build Coastguard Worker            torch.div(t2, t3)
485*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
486*da0073e9SAndroid Build Coastguard Worker            torch.div(t3, t2)
487*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
488*da0073e9SAndroid Build Coastguard Worker            foo(t2, t3)
489*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
490*da0073e9SAndroid Build Coastguard Worker            foo(t3, t2)
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker        # none of DiagonalTensor, SubdiagonalTensor, or SubTensor have a
493*da0073e9SAndroid Build Coastguard Worker        # mul or a baz implementation so all ops should raise TypeError
494*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
495*da0073e9SAndroid Build Coastguard Worker            torch.mul(t1, t1)
496*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
497*da0073e9SAndroid Build Coastguard Worker            torch.mul(t1, t2)
498*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
499*da0073e9SAndroid Build Coastguard Worker            torch.mul(t1, t3)
500*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
501*da0073e9SAndroid Build Coastguard Worker            torch.mul(t2, t1)
502*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
503*da0073e9SAndroid Build Coastguard Worker            torch.mul(t2, t2)
504*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
505*da0073e9SAndroid Build Coastguard Worker            torch.mul(t2, t3)
506*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
507*da0073e9SAndroid Build Coastguard Worker            torch.mul(t3, t1)
508*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
509*da0073e9SAndroid Build Coastguard Worker            torch.mul(t3, t2)
510*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
511*da0073e9SAndroid Build Coastguard Worker            torch.mul(t3, t3)
512*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
513*da0073e9SAndroid Build Coastguard Worker            baz(t1, t1)
514*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
515*da0073e9SAndroid Build Coastguard Worker            baz(t1, t2)
516*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
517*da0073e9SAndroid Build Coastguard Worker            baz(t1, t3)
518*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
519*da0073e9SAndroid Build Coastguard Worker            baz(t2, t1)
520*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
521*da0073e9SAndroid Build Coastguard Worker            baz(t2, t2)
522*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
523*da0073e9SAndroid Build Coastguard Worker            baz(t2, t3)
524*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
525*da0073e9SAndroid Build Coastguard Worker            baz(t3, t1)
526*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
527*da0073e9SAndroid Build Coastguard Worker            baz(t3, t2)
528*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
529*da0073e9SAndroid Build Coastguard Worker            baz(t3, t3)
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker    def test_user_implementation_raises(self):
532*da0073e9SAndroid Build Coastguard Worker        """Test that errors raised in user implementations propagate correctly"""
533*da0073e9SAndroid Build Coastguard Worker        t1 = DiagonalTensor(5, 2)
534*da0073e9SAndroid Build Coastguard Worker        t2 = DiagonalTensor(5, 2)
535*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
536*da0073e9SAndroid Build Coastguard Worker            torch.add(t1, t2)
537*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ValueError):
538*da0073e9SAndroid Build Coastguard Worker            quux(t1)
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker    def test_tensor_subclass_propagation(self):
541*da0073e9SAndroid Build Coastguard Worker        """this test exercises the functionality described in
542*da0073e9SAndroid Build Coastguard Worker        docs/source/notes/extending.rst#subclassing-torchtensor"""
543*da0073e9SAndroid Build Coastguard Worker        t1 = torch.tensor([5])
544*da0073e9SAndroid Build Coastguard Worker        t2 = torch.tensor([6])
545*da0073e9SAndroid Build Coastguard Worker
546*da0073e9SAndroid Build Coastguard Worker        s1 = SubTensor2([5])
547*da0073e9SAndroid Build Coastguard Worker        s2 = SubTensor2([6])
548*da0073e9SAndroid Build Coastguard Worker
549*da0073e9SAndroid Build Coastguard Worker        ss1 = SubSubTensor2([5])
550*da0073e9SAndroid Build Coastguard Worker        ss2 = SubSubTensor2([6])
551*da0073e9SAndroid Build Coastguard Worker
552*da0073e9SAndroid Build Coastguard Worker        sn1 = SubTensor3([5])
553*da0073e9SAndroid Build Coastguard Worker        sn2 = SubTensor3([6])
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker        # Check that leaf subclass is kept regardless of order
556*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s1 + t2, SubTensor2))
557*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(t1 + s2, SubTensor2))
558*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s1 + s2, SubTensor2))
559*da0073e9SAndroid Build Coastguard Worker
560*da0073e9SAndroid Build Coastguard Worker        # Check indexing subclass is kept
561*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s1[0], SubTensor2))
562*da0073e9SAndroid Build Coastguard Worker
563*da0073e9SAndroid Build Coastguard Worker        # Check case for subclass of subclass.
564*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
565*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(ss1 + s2, SubSubTensor2))
566*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(s1 + ss2, SubSubTensor2))
567*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
568*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(ss1 + t2, SubSubTensor2))
569*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(t1 + ss2, SubSubTensor2))
570*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(ss1[0], SubSubTensor2))
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker        # Make sure unrelated class trees are not merged.
573*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
574*da0073e9SAndroid Build Coastguard Worker            s1 + sn2
575*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
576*da0073e9SAndroid Build Coastguard Worker            sn1 + s2
577*da0073e9SAndroid Build Coastguard Worker
578*da0073e9SAndroid Build Coastguard Worker    def test_base(self):
579*da0073e9SAndroid Build Coastguard Worker        # https://github.com/szagoruyko/pytorchviz/issues/65
580*da0073e9SAndroid Build Coastguard Worker        class DummyTensor(torch.Tensor):
581*da0073e9SAndroid Build Coastguard Worker            pass
582*da0073e9SAndroid Build Coastguard Worker
583*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1)
584*da0073e9SAndroid Build Coastguard Worker        c = DummyTensor(a)
585*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(c._is_view())
586*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(c._base is a)
587*da0073e9SAndroid Build Coastguard Worker
588*da0073e9SAndroid Build Coastguard Worker    def test_grad(self):
589*da0073e9SAndroid Build Coastguard Worker        # Previously, Tensor-like objects that did not subclass from Tensor
590*da0073e9SAndroid Build Coastguard Worker        # did not get wrapped into unary tuples before being passed into
591*da0073e9SAndroid Build Coastguard Worker        # handle_torch_function, in contradiction with how Tensor-likes
592*da0073e9SAndroid Build Coastguard Worker        # were handled
593*da0073e9SAndroid Build Coastguard Worker        #
594*da0073e9SAndroid Build Coastguard Worker        # NB: this asserts that the arguments get normalized into a tuple
595*da0073e9SAndroid Build Coastguard Worker        # before entering the torch function handler; it could go the
596*da0073e9SAndroid Build Coastguard Worker        # other way but beware https://github.com/pytorch/pytorch/issues/76037
597*da0073e9SAndroid Build Coastguard Worker
598*da0073e9SAndroid Build Coastguard Worker        class Dummy:
599*da0073e9SAndroid Build Coastguard Worker            @classmethod
600*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, types, args=(), kwargs=None):
601*da0073e9SAndroid Build Coastguard Worker                inputs, outputs = args
602*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(inputs, (x,))
603*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(outputs, (x,))
604*da0073e9SAndroid Build Coastguard Worker                return -1
605*da0073e9SAndroid Build Coastguard Worker
606*da0073e9SAndroid Build Coastguard Worker        x = Dummy()
607*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.autograd.grad(x, x), -1)
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker    def test_pow_rpow(self):
610*da0073e9SAndroid Build Coastguard Worker        class NothingImplemented(torch.Tensor):
611*da0073e9SAndroid Build Coastguard Worker            @classmethod
612*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, types, args=(), kwargs=None):
613*da0073e9SAndroid Build Coastguard Worker                return NotImplemented
614*da0073e9SAndroid Build Coastguard Worker
615*da0073e9SAndroid Build Coastguard Worker        class RPowOnly(torch.Tensor):
616*da0073e9SAndroid Build Coastguard Worker            @classmethod
617*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, types, args=(), kwargs=None):
618*da0073e9SAndroid Build Coastguard Worker                if func is torch.Tensor.__rpow__:
619*da0073e9SAndroid Build Coastguard Worker                    return -1
620*da0073e9SAndroid Build Coastguard Worker                return NotImplemented
621*da0073e9SAndroid Build Coastguard Worker
622*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(NothingImplemented() ** RPowOnly(), -1)
623*da0073e9SAndroid Build Coastguard Worker
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Workerdef generate_tensor_like_override_tests(cls):
626*da0073e9SAndroid Build Coastguard Worker    from torch.testing._internal.generated.annotated_fn_args import annotated_args
627*da0073e9SAndroid Build Coastguard Worker
628*da0073e9SAndroid Build Coastguard Worker    def test_generator(func, override):
629*da0073e9SAndroid Build Coastguard Worker        # If func corresponds to a torch.Tensor method or property.
630*da0073e9SAndroid Build Coastguard Worker        if is_tensor_method_or_property(func):
631*da0073e9SAndroid Build Coastguard Worker            # Generate an instance by using SubTensor,
632*da0073e9SAndroid Build Coastguard Worker            def instance_gen():
633*da0073e9SAndroid Build Coastguard Worker                return SubTensor([5])
634*da0073e9SAndroid Build Coastguard Worker        else:
635*da0073e9SAndroid Build Coastguard Worker            # Otherwise, TensorLike.
636*da0073e9SAndroid Build Coastguard Worker            def instance_gen():
637*da0073e9SAndroid Build Coastguard Worker                return TensorLike()
638*da0073e9SAndroid Build Coastguard Worker
639*da0073e9SAndroid Build Coastguard Worker        # FIXME The following code does not support kwonly args without defaults.
640*da0073e9SAndroid Build Coastguard Worker        # The fix is easy, as one just needs to save these args when generating the variable
641*da0073e9SAndroid Build Coastguard Worker        # annotated_args. The problem is that, if one does so, one finds a number
642*da0073e9SAndroid Build Coastguard Worker        # of functions that have problematic signatures in native_functions.yaml.
643*da0073e9SAndroid Build Coastguard Worker        # Fixing these would be BC breaking, so hence this terrible hack
644*da0073e9SAndroid Build Coastguard Worker        # https://github.com/pytorch/pytorch/issues/67008
645*da0073e9SAndroid Build Coastguard Worker        kwargs = {}
646*da0073e9SAndroid Build Coastguard Worker        if hasattr(func, "__name__") and "linalg_solve_triangular" in func.__name__:
647*da0073e9SAndroid Build Coastguard Worker            kwargs = {"upper": True}
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Worker        func_args = []
650*da0073e9SAndroid Build Coastguard Worker        is_method = is_tensor_method_or_property(func)
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Worker        def _simple_type_parser(func, arg_name, arg_type):
653*da0073e9SAndroid Build Coastguard Worker            # Guess valid input to aten function based on type of argument
654*da0073e9SAndroid Build Coastguard Worker            if arg_type == "Tensor":
655*da0073e9SAndroid Build Coastguard Worker                return instance_gen()
656*da0073e9SAndroid Build Coastguard Worker            elif arg_type == "TensorList" or arg_type == "ITensorListRef":
657*da0073e9SAndroid Build Coastguard Worker                return [instance_gen(), instance_gen()]
658*da0073e9SAndroid Build Coastguard Worker            elif arg_type == "c10::List<::std::optional<Tensor>>":
659*da0073e9SAndroid Build Coastguard Worker                return [instance_gen(), instance_gen()]
660*da0073e9SAndroid Build Coastguard Worker            elif arg_type == "IntArrayRef" or arg_type == "SymIntArrayRef":
661*da0073e9SAndroid Build Coastguard Worker                size = arg.get("size", 2)
662*da0073e9SAndroid Build Coastguard Worker                if size == 1:
663*da0073e9SAndroid Build Coastguard Worker                    return 1
664*da0073e9SAndroid Build Coastguard Worker                else:
665*da0073e9SAndroid Build Coastguard Worker                    return [1] * size
666*da0073e9SAndroid Build Coastguard Worker            elif arg_type == "Scalar":
667*da0073e9SAndroid Build Coastguard Worker                return 3.5
668*da0073e9SAndroid Build Coastguard Worker            elif arg_type == "bool":
669*da0073e9SAndroid Build Coastguard Worker                return False
670*da0073e9SAndroid Build Coastguard Worker            elif arg_type == "Dimname":
671*da0073e9SAndroid Build Coastguard Worker                return ""
672*da0073e9SAndroid Build Coastguard Worker            elif arg_type == "DimnameList":
673*da0073e9SAndroid Build Coastguard Worker                return [""]
674*da0073e9SAndroid Build Coastguard Worker            elif arg_type.startswith("int"):
675*da0073e9SAndroid Build Coastguard Worker                return 0
676*da0073e9SAndroid Build Coastguard Worker            elif arg_type in {"Stream"}:
677*da0073e9SAndroid Build Coastguard Worker                return torch.Stream()
678*da0073e9SAndroid Build Coastguard Worker            elif arg_type.startswith("float") or arg_type == "double":
679*da0073e9SAndroid Build Coastguard Worker                return 1.0
680*da0073e9SAndroid Build Coastguard Worker            elif arg_type in {"Generator", "MemoryFormat", "TensorOptions"}:
681*da0073e9SAndroid Build Coastguard Worker                return None
682*da0073e9SAndroid Build Coastguard Worker            elif arg_type == "ScalarType":
683*da0073e9SAndroid Build Coastguard Worker                return torch.float32
684*da0073e9SAndroid Build Coastguard Worker            elif arg_type == "c10::string_view":
685*da0073e9SAndroid Build Coastguard Worker                return ""
686*da0073e9SAndroid Build Coastguard Worker            elif arg_type == "SymInt":
687*da0073e9SAndroid Build Coastguard Worker                # TODO: generate actual SymbolicInt
688*da0073e9SAndroid Build Coastguard Worker                return 1
689*da0073e9SAndroid Build Coastguard Worker            else:
690*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
691*da0073e9SAndroid Build Coastguard Worker                    f"Unsupported argument type {arg_type} for {arg_name} of function {func}"
692*da0073e9SAndroid Build Coastguard Worker                )
693*da0073e9SAndroid Build Coastguard Worker
694*da0073e9SAndroid Build Coastguard Worker        if func in annotated_args:
695*da0073e9SAndroid Build Coastguard Worker            for arg in annotated_args[func]:
696*da0073e9SAndroid Build Coastguard Worker                # Guess valid input to aten function based on type of argument
697*da0073e9SAndroid Build Coastguard Worker                t = arg["simple_type"]
698*da0073e9SAndroid Build Coastguard Worker                if t.endswith("?"):
699*da0073e9SAndroid Build Coastguard Worker                    t = t[:-1]
700*da0073e9SAndroid Build Coastguard Worker                if t == "Tensor" and is_method and arg["name"] == "self":
701*da0073e9SAndroid Build Coastguard Worker                    # See "Note: properties and __get__"
702*da0073e9SAndroid Build Coastguard Worker                    func = func.__get__(instance_gen())
703*da0073e9SAndroid Build Coastguard Worker                    continue
704*da0073e9SAndroid Build Coastguard Worker                arg_to_add = _simple_type_parser(func, arg["name"], t)
705*da0073e9SAndroid Build Coastguard Worker                if "is_kwarg_only" in arg and arg["is_kwarg_only"] == str(True):
706*da0073e9SAndroid Build Coastguard Worker                    kwargs[arg["name"]] = arg_to_add
707*da0073e9SAndroid Build Coastguard Worker                else:
708*da0073e9SAndroid Build Coastguard Worker                    func_args.append(arg_to_add)
709*da0073e9SAndroid Build Coastguard Worker        else:
710*da0073e9SAndroid Build Coastguard Worker            args = inspect.getfullargspec(override)
711*da0073e9SAndroid Build Coastguard Worker            try:
712*da0073e9SAndroid Build Coastguard Worker                func_args = inspect.getfullargspec(func)
713*da0073e9SAndroid Build Coastguard Worker                # Remove annotations from argspec
714*da0073e9SAndroid Build Coastguard Worker                func_args = type(func_args)(**{**func_args, 'annotations': None})
715*da0073e9SAndroid Build Coastguard Worker                if func_args != args:
716*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(f"Override for {func} doesn't match its argspec.\n"
717*da0073e9SAndroid Build Coastguard Worker                                       + f"Original: {inspect.signature(func)}\n"
718*da0073e9SAndroid Build Coastguard Worker                                       + f"Override: {inspect.signature(override)}")
719*da0073e9SAndroid Build Coastguard Worker            except TypeError:
720*da0073e9SAndroid Build Coastguard Worker                pass
721*da0073e9SAndroid Build Coastguard Worker            nargs = len(args.args)
722*da0073e9SAndroid Build Coastguard Worker            if args.defaults is not None:
723*da0073e9SAndroid Build Coastguard Worker                nargs -= len(args.defaults)
724*da0073e9SAndroid Build Coastguard Worker            func_args = [instance_gen() for _ in range(nargs)]
725*da0073e9SAndroid Build Coastguard Worker            if args.varargs is not None:
726*da0073e9SAndroid Build Coastguard Worker                func_args += [instance_gen(), instance_gen()]
727*da0073e9SAndroid Build Coastguard Worker
728*da0073e9SAndroid Build Coastguard Worker        def test(self):
729*da0073e9SAndroid Build Coastguard Worker            ret = func(*func_args, **kwargs)
730*da0073e9SAndroid Build Coastguard Worker            # ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
731*da0073e9SAndroid Build Coastguard Worker            # This is currently the best check but doesn't work for, for example,
732*da0073e9SAndroid Build Coastguard Worker            # Tensor.__add__ because it redirects to Tensor.add.
733*da0073e9SAndroid Build Coastguard Worker            # See note "_triggered wrapper"
734*da0073e9SAndroid Build Coastguard Worker            if not is_method or ret is None:
735*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered)
736*da0073e9SAndroid Build Coastguard Worker                return
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ret, -1)
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker        return test
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker    for func, override in get_testing_overrides().items():
743*da0073e9SAndroid Build Coastguard Worker        test_method = test_generator(func, override)
744*da0073e9SAndroid Build Coastguard Worker        if func.__name__ == "__get__":
745*da0073e9SAndroid Build Coastguard Worker            # Note: properties and __get__
746*da0073e9SAndroid Build Coastguard Worker            # __get__ is part of the descriptor protocol.
747*da0073e9SAndroid Build Coastguard Worker            # https://docs.python.org/3/howto/descriptor.html
748*da0073e9SAndroid Build Coastguard Worker            # This is used for properties of the form
749*da0073e9SAndroid Build Coastguard Worker            # torch.Tensor.<property>, with the method __get__
750*da0073e9SAndroid Build Coastguard Worker            # In this case we get the property name in two ways:
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker            # This case for properties defined in C.
753*da0073e9SAndroid Build Coastguard Worker            module = getattr(
754*da0073e9SAndroid Build Coastguard Worker                func.__self__,
755*da0073e9SAndroid Build Coastguard Worker                "__qualname__",
756*da0073e9SAndroid Build Coastguard Worker                None
757*da0073e9SAndroid Build Coastguard Worker            )
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker            # This one for properties defined in Python.
760*da0073e9SAndroid Build Coastguard Worker            if module is None:
761*da0073e9SAndroid Build Coastguard Worker                module = "Tensor." + func.__self__.fget.__name__
762*da0073e9SAndroid Build Coastguard Worker
763*da0073e9SAndroid Build Coastguard Worker            # Unfortunately I couldn't find a way to unify these two cases
764*da0073e9SAndroid Build Coastguard Worker            # and there is no way for general descriptors.
765*da0073e9SAndroid Build Coastguard Worker        elif is_tensor_method_or_property(func):
766*da0073e9SAndroid Build Coastguard Worker            module = "Tensor"
767*da0073e9SAndroid Build Coastguard Worker        else:
768*da0073e9SAndroid Build Coastguard Worker            module = func.__module__
769*da0073e9SAndroid Build Coastguard Worker        if module:
770*da0073e9SAndroid Build Coastguard Worker            name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__)
771*da0073e9SAndroid Build Coastguard Worker        else:
772*da0073e9SAndroid Build Coastguard Worker            name = f'test_{func.__name__}'
773*da0073e9SAndroid Build Coastguard Worker        test_method.__name__ = name
774*da0073e9SAndroid Build Coastguard Worker        setattr(cls, name, test_method)
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Workergenerate_tensor_like_override_tests(TestTorchFunctionOverride)
777*da0073e9SAndroid Build Coastguard Worker
class Wrapper:
    """Basic data container that knows how to unwrap itself.

    Holds an arbitrary object in ``_data`` (in these tests, a
    ``torch.Tensor``) and routes attribute reads, method calls and the
    arithmetic/comparison operators through ``__torch_function__``.
    Attribute names read or written through the wrapper are recorded in
    ``used_attrs``; functions dispatched through ``__torch_function__`` are
    recorded in ``used_calls`` of every Wrapper argument.
    """
    def __init__(self, data):
        # Populate __dict__ directly: our __setattr__ would otherwise forward
        # these assignments to the wrapped object.
        self.__dict__["_data"] = data
        self.__dict__["used_attrs"] = set()
        self.__dict__["used_calls"] = set()

    def __getattr__(self, name):
        # Only reached when normal lookup fails. An instance created without
        # __init__ (copy.copy, unpickling) has no bookkeeping state yet, and
        # touching self.used_attrs / self._data below would re-enter this
        # method forever; report a plain missing attribute instead.
        if "_data" not in self.__dict__ or "used_attrs" not in self.__dict__:
            raise AttributeError(name)
        if name in self.__dict__:
            return self.__dict__[name]
        self.used_attrs.add(name)

        val = getattr(self._data, name)

        # Callables (other than torch.device values) are returned as closures
        # that dispatch through __torch_function__ and re-wrap the result.
        if not isinstance(val, torch.device) and callable(val):
            c = getattr(type(self._data), name)
            # A staticmethod is the same object whether fetched from the type
            # or from the instance, so it is called without prepending self.
            # NOTE(review): classmethods yield a fresh bound method on each
            # lookup, so they fail this identity test and get self prepended.
            if c is val:
                return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw))
            # Instance method: pass the wrapper itself as the first argument.
            return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw))

        return wrap(val)

    def __setattr__(self, name, value):
        # Bookkeeping attributes (_data, used_attrs, used_calls) belong to the
        # wrapper itself: update them in place and stop, instead of also
        # recording them and copying them onto the wrapped object.
        if name in self.__dict__:
            self.__dict__[name] = value
            return

        # Everything else is recorded and forwarded (unwrapped) to the data.
        self.used_attrs.add(name)
        setattr(self._data, name, unwrap(value))

    def __setitem__(self, key, value):
        self._data[unwrap(key)] = unwrap(value)

    def __getitem__(self, key):
        return wrap(self._data[unwrap(key)])

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Record *func* on every Wrapper argument, then run it on the raw data.

        Wrapper instances are searched for among the positional arguments and
        one level inside positional sequences (e.g. einsum's operand list);
        at least one must be present. Arguments are unwrapped before the call
        and the result is wrapped again.
        """
        if kwargs is None:
            kwargs = {}
        # Find an instance of this class in the arguments
        args_of_this_cls = []
        for a in args:
            if isinstance(a, cls):
                args_of_this_cls.append(a)
            elif isinstance(a, collections.abc.Sequence):
                args_of_this_cls.extend(el for el in a if isinstance(el, cls))
        assert len(args_of_this_cls) > 0
        for a in args_of_this_cls:
            a.used_calls.add(func)
        args = unwrap(tuple(args))
        kwargs = {k: unwrap(v) for k, v in kwargs.items()}

        return wrap(func(*args, **kwargs))

    # Operators are routed through __torch_function__ so they are recorded in
    # used_calls exactly like explicit torch.* calls.
    def __add__(self, other):
        return self.__torch_function__(torch.add, (Wrapper,), (self, other))

    def __mul__(self, other):
        return self.__torch_function__(torch.mul, (Wrapper,), (self, other))

    def __sub__(self, other):
        return self.__torch_function__(torch.sub, (Wrapper,), (self, other))

    def __truediv__(self, other):
        return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other))

    def __floordiv__(self, other):
        return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other))

    def __ge__(self, other):
        return self.__torch_function__(torch.ge, (Wrapper,), (self, other))

    def __gt__(self, other):
        return self.__torch_function__(torch.gt, (Wrapper,), (self, other))

    def __lt__(self, other):
        return self.__torch_function__(torch.lt, (Wrapper,), (self, other))

    def __le__(self, other):
        return self.__torch_function__(torch.le, (Wrapper,), (self, other))

    # NOTE: defining __eq__ without __hash__ makes Wrapper instances unhashable.
    def __eq__(self, other):
        return self.__torch_function__(torch.eq, (Wrapper,), (self, other))

    def __ne__(self, other):
        return self.__torch_function__(torch.ne, (Wrapper,), (self, other))

    def __bool__(self):
        return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,))

    def __int__(self):
        return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,))

    def __len__(self):
        return len(self._data)
876*da0073e9SAndroid Build Coastguard Worker
877*da0073e9SAndroid Build Coastguard Worker
def unwrap(v):
    """Strip ``Wrapper`` layers from *v*, recursing into plain tuples/lists.

    Only exact ``tuple``/``list`` instances are rebuilt element by element
    (subclasses such as named tuples pass through untouched); a ``Wrapper``
    yields its ``_data``; every other value is returned unchanged.
    """
    seq_type = type(v)
    if seq_type is tuple or seq_type is list:
        return seq_type(map(unwrap, v))
    if isinstance(v, Wrapper):
        return v._data
    return v
884*da0073e9SAndroid Build Coastguard Worker
def wrap(v):
    """Box every ``torch.Tensor`` in *v* into a ``Wrapper``.

    Exact ``tuple``/``list`` containers are rebuilt element by element;
    any other value (including tuple subclasses) is returned unchanged
    unless it is itself a tensor.
    """
    seq_type = type(v)
    if seq_type is tuple or seq_type is list:
        return seq_type(map(wrap, v))
    if isinstance(v, torch.Tensor):
        return Wrapper(v)
    return v
891*da0073e9SAndroid Build Coastguard Worker
class TestEinsumOverride(TestCase):
    """Regression test for gh-38479 (torch.einsum with ``Wrapper`` operands)."""
    def test_wrapper(self):
        # Operands passed as separate positional arguments.
        lhs = Wrapper(torch.randn(5))
        rhs = Wrapper(torch.randn(4))
        outer = torch.einsum('i,j->ij', lhs, rhs)
        self.assertEqual(outer._data, torch.ger(lhs, rhs)._data)

        # In the old einsum interface, `operands` is a single list.
        first = Wrapper(torch.randn(2, 3))
        weight = Wrapper(torch.randn(5, 3, 7))
        second = Wrapper(torch.randn(2, 7))
        contracted = torch.einsum('ik,jkl,il->ij', [first, weight, second])
        self.assertEqual(contracted._data,
                         torch.nn.functional.bilinear(first, second, weight)._data)
906*da0073e9SAndroid Build Coastguard Worker
class TestGradCheckOverride(TestCase):
    """Test that wrappers work with gradcheck and gradgradcheck."""
    def test_gradcheck(self):
        from torch.testing._internal.common_utils import gradcheck, gradgradcheck

        def run_test(fast_mode):
            lhs = wrap(torch.tensor(5.0, dtype=torch.double))
            rhs = wrap(torch.tensor(6.0, dtype=torch.double))
            for operand in (lhs, rhs):
                operand.requires_grad = True

            for checker in (gradcheck, gradgradcheck):
                checker(torch.add, (lhs, rhs), raise_exception=False,
                        check_batched_grad=False, fast_mode=fast_mode)

            # These attributes (and the functions below) may change if the
            # gradcheck implementation changes. Aim for attributes that other
            # Tensor-likes are also likely to provide.
            attrs_expected = {
                'data', 'dtype', 'is_floating_point', 'is_sparse', 'layout',
                'new_zeros', 'numel', 'requires_grad', 'requires_grad_',
                'size', 'stride',
            }
            calls_expected = {
                torch.Tensor.new_zeros,
                torch.Tensor.size,
                torch.Tensor.is_floating_point,
                torch.Tensor.numel,
                torch.Tensor.stride,
                torch.Tensor.requires_grad_,
                torch.autograd.grad,
                torch.add,
            }
            if fast_mode:
                # Fast mode additionally inspects complexity and device.
                attrs_expected |= {'is_complex', 'device'}
                calls_expected.add(torch.Tensor.is_complex)

            self.assertEqual(attrs_expected, lhs.used_attrs | rhs.used_attrs)
            self.assertEqual(calls_expected, lhs.used_calls | rhs.used_calls)

        for fast_mode in (True, False):
            run_test(fast_mode=fast_mode)
962*da0073e9SAndroid Build Coastguard Worker
class TestNamedTuple(TestCase):
    """Regression test for gh-47090: torch.max result type is unchanged for a subclass."""
    def test_max(self):
        base = torch.tensor([1, 2])
        sub = base.as_subclass(SubTensor2)
        expected = torch.max(base, dim=0)
        actual = torch.max(sub, dim=0)
        self.assertEqual(type(expected), type(actual))
        self.assertEqual(expected, actual)
972*da0073e9SAndroid Build Coastguard Worker
class TestGradNewOnesOverride(TestCase):
    """Regression test for gh-47069: new_ones keeps the subclass type."""
    def test_newones(self):
        source = torch.tensor([1, 2]).as_subclass(SubTensor2)
        created = source.new_ones((1, 2))
        self.assertEqual(type(created), SubTensor2)
979*da0073e9SAndroid Build Coastguard Worker
class TestPickle(TestCase):
    """Regression test for gh-47051: subclass type and Python attributes survive pickling."""
    def test_pickle(self):
        original = torch.tensor([1]).as_subclass(SubTensor2)
        original.abcd = "e"
        payload = pickle.dumps(original)
        restored = pickle.loads(payload)
        self.assertIs(type(restored), SubTensor2)
        self.assertEqual(restored.abcd, "e")
988*da0073e9SAndroid Build Coastguard Worker
class TestBroadcastAllOverride(TestCase):
    """Test for gh-37141: broadcast_all with ``Wrapper`` inputs."""
    def test_broadcast_all(self):
        from torch.distributions.utils import broadcast_all
        vector = torch.tensor([1.2, 3.4, 5.6])
        wrapped_vector = Wrapper(vector)
        scalar = torch.tensor(5.0)
        wrapped_scalar = Wrapper(scalar)
        broadcast_scalar = torch.tensor([5.0, 5.0, 5.0])

        # First with both operands wrapped, then with only the first one:
        # in both cases every output is a Wrapper around the broadcast data.
        for second in (wrapped_scalar, scalar):
            outputs = broadcast_all(wrapped_vector, second)
            self.assertIsInstance(outputs[0], Wrapper)
            self.assertIsInstance(outputs[1], Wrapper)
            self.assertEqual(outputs[0]._data, vector)
            self.assertEqual(outputs[1]._data, broadcast_scalar)
1010*da0073e9SAndroid Build Coastguard Worker
class TestWrapTorchFunction(TestCase):
    def test_wrap_torch_function(self):
        # An object whose __torch_function__ answers every call with -1.
        class MinusOne:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs):
                return -1

        def relevant_args(a):
            return (a,)

        @torch.overrides.wrap_torch_function(relevant_args)
        def identity(a):
            return a

        # The decorated function hands off to MinusOne.__torch_function__
        # instead of returning its argument.
        self.assertEqual(identity(MinusOne()), -1)
1026*da0073e9SAndroid Build Coastguard Worker
class TestIndexing(TestCase):
    """Regression tests for gh-46277: indexing with __torch_function__ objects."""
    def test_getitem(self):
        class Index:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                return -1

        target = torch.tensor([5])
        self.assertEqual(target[Index()], -1)
        # Indexing must leave the tensor untouched.
        self.assertEqual(target, torch.tensor([5]))

    def test_getitem_subclass(self):
        class TensorIndex(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                return -1

        target = torch.tensor([5])
        self.assertEqual(target[TensorIndex()], -1)
        self.assertEqual(target[5, TensorIndex()], -1)
        self.assertEqual(target, torch.tensor([5]))

    def test_setitem(self):
        calls = set()

        class Index:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                calls.add(func)
                return -1

        target = torch.tensor([5])
        target[Index()] = 1
        target[5, Index()] = 1
        self.assertIn(Tensor.__setitem__, calls)
        # The override intercepted the assignments, so nothing was written.
        self.assertEqual(target, torch.tensor([5]))

    def test_setitem_val(self):
        calls = set()

        class Value:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                calls.add(func)
                return -1

        target = torch.tensor([5])
        target[0] = Value()
        self.assertIn(Tensor.__setitem__, calls)
        self.assertEqual(target, torch.tensor([5]))

    def test_setitem_subclass(self):
        calls = set()

        class TensorIndex(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                calls.add(func)
                return -1

        target = torch.tensor([5])
        target[TensorIndex()] = 1
        target[5, TensorIndex()] = 1
        self.assertIn(Tensor.__setitem__, calls)
        self.assertEqual(target, torch.tensor([5]))
1093*da0073e9SAndroid Build Coastguard Worker
1094*da0073e9SAndroid Build Coastguard Worker
class TestIterator(TestCase):
    # Regression test for gh-54457: iterating a subclass yields subclass elements.
    def test_iterator(self):
        values = torch.tensor([5, 6, 7]).as_subclass(SubTensor2)
        elements = iter(values)
        for _ in range(3):
            self.assertIs(type(next(elements)), SubTensor2)
1103*da0073e9SAndroid Build Coastguard Worker
1104*da0073e9SAndroid Build Coastguard Worker
class TestRNN(TestCase):
    # Regression test for gh-55868: an RNN must accept a Wrapper input.
    def test_rnn(self):
        rnn = torch.nn.RNN(10, 20, 2)
        rnn(Wrapper(torch.randn(1, 5, 10)))
1111*da0073e9SAndroid Build Coastguard Worker
1112*da0073e9SAndroid Build Coastguard Worker
class TestDisabledTorchFunction(TestCase):
    # Regression test for gh-64687: a Parameter argument must not block
    # dispatch to another argument's __torch_function__.
    def test_parameter_does_not_prevent_dispatch(self):
        class MyTensor:
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                return "called"

        custom = MyTensor()
        param = torch.nn.Parameter(torch.rand(2, 2))
        self.assertEqual(torch.add(param, custom), "called")

        inp = torch.rand(10, 10)
        linear = torch.nn.functional.linear
        self.assertEqual(linear(inp, custom, param), "called")
        self.assertEqual(linear(inp, param, custom), "called")
1128*da0073e9SAndroid Build Coastguard Worker
class TestResolveName(TestCase):
    def test_resolve_name(self):
        # Every overridable function's resolved dotted name must evaluate
        # back to that same function object.
        for overridables in get_overridable_functions().values():
            for fn in overridables:
                qualified = torch.overrides.resolve_name(fn)
                self.assertEqual(eval(qualified), fn, msg=f"{fn}, {qualified}")
1138*da0073e9SAndroid Build Coastguard Worker
class TestTorchFunctionWarning(TestCase):
    def _assert_plain_method_deprecated(self, obj):
        # A function that handles torch_function on the Python side emits a
        # DeprecationWarning; one handled in C++ emits a UserWarning.
        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
            torch.nn.functional.dropout(obj)
        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
            torch.abs(obj)

    def test_warn_on_invalid_torch_function_standalone_class(self):
        class StandaloneTorchFunctionClass:
            def __torch_function__(self, *args, **kwargs):
                pass
        self._assert_plain_method_deprecated(StandaloneTorchFunctionClass())

    def test_warn_on_invalid_torch_function_tensor_subclass(self):
        class TensorSubclassTorchFunctionClass(torch.Tensor):
            def __torch_function__(self, *args, **kwargs):
                pass
        self._assert_plain_method_deprecated(TensorSubclassTorchFunctionClass())
1163*da0073e9SAndroid Build Coastguard Worker
class TestDisabledUserWarnings(TestCase):
    def test_no_implicit_user_warning_for_deprecated_functions(self):
        # Each public torch.overrides helper must run without warning.
        checks = (
            get_ignored_functions,
            get_testing_overrides,
            get_overridable_functions,
            lambda: resolve_name(torch.Tensor.add),
            lambda: is_tensor_method_or_property(torch.Tensor.add),
        )
        for check in checks:
            self.assertNotWarn(check)
1171*da0073e9SAndroid Build Coastguard Worker
1172*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(TEST_WITH_CROSSREF, "not run with crossref")
1173*da0073e9SAndroid Build Coastguard Workerclass TestTorchFunctionMode(TestCase):
1174*da0073e9SAndroid Build Coastguard Worker    def test_basic(self):
1175*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1176*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, *args, **kwargs):
1177*da0073e9SAndroid Build Coastguard Worker                return -1
1178*da0073e9SAndroid Build Coastguard Worker        # NB: factory functions get overridden too!
1179*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1)
1180*da0073e9SAndroid Build Coastguard Worker        with A():
1181*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.randn(3), -1)
1182*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.add(x, x), -1)
1183*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.split(None, [2]), -1)  # python side
1184*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bar(x), -1)
1185*da0073e9SAndroid Build Coastguard Worker
1186*da0073e9SAndroid Build Coastguard Worker    def test_factory_override(self):
1187*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1188*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, *args, **kwargs):
1189*da0073e9SAndroid Build Coastguard Worker                return -1
1190*da0073e9SAndroid Build Coastguard Worker
1191*da0073e9SAndroid Build Coastguard Worker        with A():
1192*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.tensor([1]), -1)
1193*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.sparse_coo_tensor(1, 1, 1), -1)
1194*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1), -1)
1195*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.sparse_coo_tensor(1, 1, (1, 1), check_invariants=False), -1)
1196*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1, (1, 1), check_invariants=False), -1)
1197*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.as_tensor([1]), -1)
1198*da0073e9SAndroid Build Coastguard Worker
1199*da0073e9SAndroid Build Coastguard Worker    def test_modes_handle_first(self):
1200*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1201*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, *args, **kwargs):
1202*da0073e9SAndroid Build Coastguard Worker                return -40
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker        x = SubTensor()
1205*da0073e9SAndroid Build Coastguard Worker        with A():
1206*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.neg(x), -40)
1207*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.mean(x), -40)
1208*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.mm(x, x), -40)
1209*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bar(x), -40)
1210*da0073e9SAndroid Build Coastguard Worker
1211*da0073e9SAndroid Build Coastguard Worker    def test_modes_return_notimplemented(self):
1212*da0073e9SAndroid Build Coastguard Worker        class MyMode(TorchFunctionMode):
1213*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, *args, **kwargs):
1214*da0073e9SAndroid Build Coastguard Worker                return NotImplemented
1215*da0073e9SAndroid Build Coastguard Worker
1216*da0073e9SAndroid Build Coastguard Worker        x = SubTensor()
1217*da0073e9SAndroid Build Coastguard Worker        with MyMode():
1218*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.mean(x), 0)
1219*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.mm(x, x), -1)
1220*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bar(x), 1)
1221*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(
1222*da0073e9SAndroid Build Coastguard Worker                TypeError, r'SubTensor',
1223*da0073e9SAndroid Build Coastguard Worker                lambda: self.assertEqual(torch.max(x, x)))
1224*da0073e9SAndroid Build Coastguard Worker
1225*da0073e9SAndroid Build Coastguard Worker    def test_with_mode(self):
1226*da0073e9SAndroid Build Coastguard Worker        class ErrorA(RuntimeError):
1227*da0073e9SAndroid Build Coastguard Worker            pass
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1230*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, *args, **kwargs):
1231*da0073e9SAndroid Build Coastguard Worker                raise ErrorA
1232*da0073e9SAndroid Build Coastguard Worker
1233*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ErrorA):
1234*da0073e9SAndroid Build Coastguard Worker            with A():
1235*da0073e9SAndroid Build Coastguard Worker                torch.empty([])
1236*da0073e9SAndroid Build Coastguard Worker
1237*da0073e9SAndroid Build Coastguard Worker    def test_with_mode_created_separately(self):
1238*da0073e9SAndroid Build Coastguard Worker        class ErrorA(RuntimeError):
1239*da0073e9SAndroid Build Coastguard Worker            pass
1240*da0073e9SAndroid Build Coastguard Worker
1241*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1242*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, *args, **kwargs):
1243*da0073e9SAndroid Build Coastguard Worker                raise ErrorA
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker        x = A()
1246*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ErrorA):
1247*da0073e9SAndroid Build Coastguard Worker            with x:
1248*da0073e9SAndroid Build Coastguard Worker                torch.empty([])
1249*da0073e9SAndroid Build Coastguard Worker
1250*da0073e9SAndroid Build Coastguard Worker    def test_with_nested_modes(self):
1251*da0073e9SAndroid Build Coastguard Worker        out = []
1252*da0073e9SAndroid Build Coastguard Worker
1253*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1254*da0073e9SAndroid Build Coastguard Worker            def __init__(self, msg):
1255*da0073e9SAndroid Build Coastguard Worker                self.msg = msg
1256*da0073e9SAndroid Build Coastguard Worker
1257*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, _, args=(), kwargs=None):
1258*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
1259*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
1260*da0073e9SAndroid Build Coastguard Worker                out.append(self.msg)
1261*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1262*da0073e9SAndroid Build Coastguard Worker
1263*da0073e9SAndroid Build Coastguard Worker        with A("layer1"):
1264*da0073e9SAndroid Build Coastguard Worker            with A("layer2"):
1265*da0073e9SAndroid Build Coastguard Worker                torch.empty([])
1266*da0073e9SAndroid Build Coastguard Worker
1267*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, ["layer2", "layer1"])
1268*da0073e9SAndroid Build Coastguard Worker
1269*da0073e9SAndroid Build Coastguard Worker    def test_nested_same_mode(self):
1270*da0073e9SAndroid Build Coastguard Worker        out = []
1271*da0073e9SAndroid Build Coastguard Worker
1272*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1273*da0073e9SAndroid Build Coastguard Worker            def __init__(self, msg):
1274*da0073e9SAndroid Build Coastguard Worker                self.msg = msg
1275*da0073e9SAndroid Build Coastguard Worker
1276*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, _, args=(), kwargs=None):
1277*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
1278*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
1279*da0073e9SAndroid Build Coastguard Worker                out.append(self.msg)
1280*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1281*da0073e9SAndroid Build Coastguard Worker
1282*da0073e9SAndroid Build Coastguard Worker        with A("layer1") as a:
1283*da0073e9SAndroid Build Coastguard Worker            with a:
1284*da0073e9SAndroid Build Coastguard Worker                torch.empty([])
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, ["layer1", "layer1"])
1287*da0073e9SAndroid Build Coastguard Worker
1288*da0073e9SAndroid Build Coastguard Worker    def test_error_using_class_method_on_mode(self):
1289*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1290*da0073e9SAndroid Build Coastguard Worker            @classmethod
1291*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, _, args=(), kwargs=None):
1292*da0073e9SAndroid Build Coastguard Worker                return func(args, kwargs)
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(5.)
1295*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"):
1296*da0073e9SAndroid Build Coastguard Worker            with A():
1297*da0073e9SAndroid Build Coastguard Worker                x + x
1298*da0073e9SAndroid Build Coastguard Worker
1299*da0073e9SAndroid Build Coastguard Worker    def test_restacking_with_ancestor(self):
1300*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1301*da0073e9SAndroid Build Coastguard Worker            pass
1302*da0073e9SAndroid Build Coastguard Worker
1303*da0073e9SAndroid Build Coastguard Worker        with A():
1304*da0073e9SAndroid Build Coastguard Worker            with A() as x:
1305*da0073e9SAndroid Build Coastguard Worker                pass
1306*da0073e9SAndroid Build Coastguard Worker
1307*da0073e9SAndroid Build Coastguard Worker        with x:
1308*da0073e9SAndroid Build Coastguard Worker            pass
1309*da0073e9SAndroid Build Coastguard Worker
1310*da0073e9SAndroid Build Coastguard Worker    def test_get_cur_mode(self):
1311*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1312*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1313*da0073e9SAndroid Build Coastguard Worker                pass
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker        with A() as mode1:
1316*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_get_current_function_mode(), mode1)
1317*da0073e9SAndroid Build Coastguard Worker
1318*da0073e9SAndroid Build Coastguard Worker        with mode1:
1319*da0073e9SAndroid Build Coastguard Worker            with A() as mode2:
1320*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(_get_current_function_mode(), mode2)
1321*da0073e9SAndroid Build Coastguard Worker
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker    def test_get_mode_stack(self):
1324*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1325*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1326*da0073e9SAndroid Build Coastguard Worker                pass
1327*da0073e9SAndroid Build Coastguard Worker
1328*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(_get_current_function_mode_stack(), [])
1329*da0073e9SAndroid Build Coastguard Worker
1330*da0073e9SAndroid Build Coastguard Worker        with A() as mode1:
1331*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_get_current_function_mode_stack(), [mode1])
1332*da0073e9SAndroid Build Coastguard Worker
1333*da0073e9SAndroid Build Coastguard Worker        with mode1:
1334*da0073e9SAndroid Build Coastguard Worker            with A() as mode2:
1335*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2])
1336*da0073e9SAndroid Build Coastguard Worker
1337*da0073e9SAndroid Build Coastguard Worker    def test_all_same_mode(self):
1338*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1339*da0073e9SAndroid Build Coastguard Worker            pass
1340*da0073e9SAndroid Build Coastguard Worker
1341*da0073e9SAndroid Build Coastguard Worker        x = A()
1342*da0073e9SAndroid Build Coastguard Worker        y = A()
1343*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(all_same_mode([x, x, x]))
1344*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(all_same_mode([x, None]))
1345*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(all_same_mode([x, y]))
1346*da0073e9SAndroid Build Coastguard Worker
1347*da0073e9SAndroid Build Coastguard Worker    def test_nested_modes_with_python_has_torch_function(self):
1348*da0073e9SAndroid Build Coastguard Worker        called = []
1349*da0073e9SAndroid Build Coastguard Worker
1350*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1351*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, types, args=(), kwargs=None):
1352*da0073e9SAndroid Build Coastguard Worker                called.append("A")
1353*da0073e9SAndroid Build Coastguard Worker                kwargs = {} if kwargs is None else kwargs
1354*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1355*da0073e9SAndroid Build Coastguard Worker
1356*da0073e9SAndroid Build Coastguard Worker        class B(TorchFunctionMode):
1357*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, types, args=(), kwargs=None):
1358*da0073e9SAndroid Build Coastguard Worker                called.append("B")
1359*da0073e9SAndroid Build Coastguard Worker                kwargs = {} if kwargs is None else kwargs
1360*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1361*da0073e9SAndroid Build Coastguard Worker
1362*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 4)
1363*da0073e9SAndroid Build Coastguard Worker        with A():
1364*da0073e9SAndroid Build Coastguard Worker            with B():
1365*da0073e9SAndroid Build Coastguard Worker                y = bar(x)
1366*da0073e9SAndroid Build Coastguard Worker
1367*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, x)
1368*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(called, ["B", "A"])
1369*da0073e9SAndroid Build Coastguard Worker
1370*da0073e9SAndroid Build Coastguard Worker
1371*da0073e9SAndroid Build Coastguard Worker    def test_reentrant_mode_idiom(self):
1372*da0073e9SAndroid Build Coastguard Worker        log = []
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1375*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, types, args=(), kwargs=None):
1376*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
1377*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
1378*da0073e9SAndroid Build Coastguard Worker                log.append(func)
1379*da0073e9SAndroid Build Coastguard Worker                if func is torch.sub:
1380*da0073e9SAndroid Build Coastguard Worker                    with self:
1381*da0073e9SAndroid Build Coastguard Worker                        input, other = args
1382*da0073e9SAndroid Build Coastguard Worker                        assert not kwargs
1383*da0073e9SAndroid Build Coastguard Worker                        return torch.add(input, other, alpha=-1)
1384*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1385*da0073e9SAndroid Build Coastguard Worker
1386*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1)
1387*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(1)
1388*da0073e9SAndroid Build Coastguard Worker        with A():
1389*da0073e9SAndroid Build Coastguard Worker            torch.sub(x, y)
1390*da0073e9SAndroid Build Coastguard Worker        # add hits the torch function again!
1391*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(log, [torch.sub, torch.add])
1392*da0073e9SAndroid Build Coastguard Worker
1393*da0073e9SAndroid Build Coastguard Worker    def test_nn_parse_to(self):
1394*da0073e9SAndroid Build Coastguard Worker        # This failed because the parser thinks the function is called to()
1395*da0073e9SAndroid Build Coastguard Worker        # but it's actually called _parse_to()
1396*da0073e9SAndroid Build Coastguard Worker
1397*da0073e9SAndroid Build Coastguard Worker        called = False
1398*da0073e9SAndroid Build Coastguard Worker
1399*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1400*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, types, args=(), kwargs=None):
1401*da0073e9SAndroid Build Coastguard Worker                nonlocal called
1402*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
1403*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
1404*da0073e9SAndroid Build Coastguard Worker                called = True
1405*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1406*da0073e9SAndroid Build Coastguard Worker
1407*da0073e9SAndroid Build Coastguard Worker        with A():
1408*da0073e9SAndroid Build Coastguard Worker            torch._C._nn._parse_to('cpu')
1409*da0073e9SAndroid Build Coastguard Worker
1410*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(called)
1411*da0073e9SAndroid Build Coastguard Worker
1412*da0073e9SAndroid Build Coastguard Worker    def test_getitem_call(self):
1413*da0073e9SAndroid Build Coastguard Worker        # This failed because the parser thinks the function is called to()
1414*da0073e9SAndroid Build Coastguard Worker        # but it's actually called _parse_to()
1415*da0073e9SAndroid Build Coastguard Worker
1416*da0073e9SAndroid Build Coastguard Worker        called = False
1417*da0073e9SAndroid Build Coastguard Worker
1418*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1419*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, types, args=(), kwargs=None):
1420*da0073e9SAndroid Build Coastguard Worker                nonlocal called
1421*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
1422*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
1423*da0073e9SAndroid Build Coastguard Worker                called = True
1424*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1425*da0073e9SAndroid Build Coastguard Worker
1426*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(5)
1427*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(0)
1428*da0073e9SAndroid Build Coastguard Worker        with A():
1429*da0073e9SAndroid Build Coastguard Worker            a[b]
1430*da0073e9SAndroid Build Coastguard Worker
1431*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(called)
1432*da0073e9SAndroid Build Coastguard Worker
1433*da0073e9SAndroid Build Coastguard Worker
1434*da0073e9SAndroid Build Coastguard Worker    def test_distributions_bernoulli(self):
1435*da0073e9SAndroid Build Coastguard Worker        # This failed because improper use of has_torch_function when
1436*da0073e9SAndroid Build Coastguard Worker        # is_tensor_like should have been used instead, inside the
1437*da0073e9SAndroid Build Coastguard Worker        # broadcasting logic called by distributions (Bernoulli doesn't
1438*da0073e9SAndroid Build Coastguard Worker        # matter per se)
1439*da0073e9SAndroid Build Coastguard Worker
1440*da0073e9SAndroid Build Coastguard Worker        called = False
1441*da0073e9SAndroid Build Coastguard Worker
1442*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1443*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, types, args=(), kwargs=None):
1444*da0073e9SAndroid Build Coastguard Worker                nonlocal called
1445*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
1446*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
1447*da0073e9SAndroid Build Coastguard Worker                called = True
1448*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1449*da0073e9SAndroid Build Coastguard Worker
1450*da0073e9SAndroid Build Coastguard Worker        with A():
1451*da0073e9SAndroid Build Coastguard Worker            torch.distributions.Bernoulli(0.3)
1452*da0073e9SAndroid Build Coastguard Worker
1453*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(called)
1454*da0073e9SAndroid Build Coastguard Worker
1455*da0073e9SAndroid Build Coastguard Worker    def test_mode_notimplemented_loop(self):
1456*da0073e9SAndroid Build Coastguard Worker        # Default tensor subclass implementation disables torch function;
1457*da0073e9SAndroid Build Coastguard Worker        # when we redispatch to mode we must not treat the objects as
1458*da0073e9SAndroid Build Coastguard Worker        # eligible
1459*da0073e9SAndroid Build Coastguard Worker
1460*da0073e9SAndroid Build Coastguard Worker        called = 0
1461*da0073e9SAndroid Build Coastguard Worker
1462*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1463*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, types, args=(), kwargs=None):
1464*da0073e9SAndroid Build Coastguard Worker                nonlocal called
1465*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
1466*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
1467*da0073e9SAndroid Build Coastguard Worker                called += 1
1468*da0073e9SAndroid Build Coastguard Worker                # The first time we call, the mode sees an active type that
1469*da0073e9SAndroid Build Coastguard Worker                # it doesn't know how to deal with.  The second time, we're
1470*da0073e9SAndroid Build Coastguard Worker                # instructed to treat it "as if it were a tensor", and so
1471*da0073e9SAndroid Build Coastguard Worker                # we keep going.  I'm not entirely clear if the subclasses
1472*da0073e9SAndroid Build Coastguard Worker                # disappearing from types is the correct way to do it.
1473*da0073e9SAndroid Build Coastguard Worker                if any(t is not torch.Tensor for t in types):
1474*da0073e9SAndroid Build Coastguard Worker                    return NotImplemented
1475*da0073e9SAndroid Build Coastguard Worker                else:
1476*da0073e9SAndroid Build Coastguard Worker                    return func(*args, **kwargs)
1477*da0073e9SAndroid Build Coastguard Worker
1478*da0073e9SAndroid Build Coastguard Worker        class B(torch.Tensor):
1479*da0073e9SAndroid Build Coastguard Worker            pass
1480*da0073e9SAndroid Build Coastguard Worker
1481*da0073e9SAndroid Build Coastguard Worker        b = B()
1482*da0073e9SAndroid Build Coastguard Worker
1483*da0073e9SAndroid Build Coastguard Worker        with A():
1484*da0073e9SAndroid Build Coastguard Worker            r = torch.neg(b)
1485*da0073e9SAndroid Build Coastguard Worker
1486*da0073e9SAndroid Build Coastguard Worker        self.assertIs(type(r), B)
1487*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(called, 2)
1488*da0073e9SAndroid Build Coastguard Worker
1489*da0073e9SAndroid Build Coastguard Worker        called = 0
1490*da0073e9SAndroid Build Coastguard Worker
1491*da0073e9SAndroid Build Coastguard Worker        with A():
1492*da0073e9SAndroid Build Coastguard Worker            r = bar(b)
1493*da0073e9SAndroid Build Coastguard Worker
1494*da0073e9SAndroid Build Coastguard Worker        self.assertIs(type(r), B)
1495*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(called, 2)
1496*da0073e9SAndroid Build Coastguard Worker
1497*da0073e9SAndroid Build Coastguard Worker    def test_disable_subclass_not_mode(self):
1498*da0073e9SAndroid Build Coastguard Worker        called = False
1499*da0073e9SAndroid Build Coastguard Worker
1500*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1501*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, types, args=(), kwargs=None):
1502*da0073e9SAndroid Build Coastguard Worker                nonlocal called
1503*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
1504*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
1505*da0073e9SAndroid Build Coastguard Worker                called = True
1506*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1507*da0073e9SAndroid Build Coastguard Worker
1508*da0073e9SAndroid Build Coastguard Worker        class B(torch.Tensor):
1509*da0073e9SAndroid Build Coastguard Worker            pass
1510*da0073e9SAndroid Build Coastguard Worker
1511*da0073e9SAndroid Build Coastguard Worker        x = B(torch.randn(5))
1512*da0073e9SAndroid Build Coastguard Worker        with A():
1513*da0073e9SAndroid Build Coastguard Worker            with torch._C.DisableTorchFunctionSubclass():
1514*da0073e9SAndroid Build Coastguard Worker                self.assertNotIsInstance(torch.sum(x), B)
1515*da0073e9SAndroid Build Coastguard Worker
1516*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(called)
1517*da0073e9SAndroid Build Coastguard Worker
1518*da0073e9SAndroid Build Coastguard Worker    def test_disable_subclass_mode(self):
1519*da0073e9SAndroid Build Coastguard Worker        called = False
1520*da0073e9SAndroid Build Coastguard Worker
1521*da0073e9SAndroid Build Coastguard Worker        class A(TorchFunctionMode):
1522*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, types, args=(), kwargs=None):
1523*da0073e9SAndroid Build Coastguard Worker                nonlocal called
1524*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
1525*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
1526*da0073e9SAndroid Build Coastguard Worker                called = True
1527*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1528*da0073e9SAndroid Build Coastguard Worker
1529*da0073e9SAndroid Build Coastguard Worker        class B(torch.Tensor):
1530*da0073e9SAndroid Build Coastguard Worker            pass
1531*da0073e9SAndroid Build Coastguard Worker
1532*da0073e9SAndroid Build Coastguard Worker        x = B(torch.randn(5))
1533*da0073e9SAndroid Build Coastguard Worker        with A():
1534*da0073e9SAndroid Build Coastguard Worker            with torch._C.DisableTorchFunction():
1535*da0073e9SAndroid Build Coastguard Worker                self.assertNotIsInstance(torch.sum(x), B)
1536*da0073e9SAndroid Build Coastguard Worker
1537*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(called)
1538*da0073e9SAndroid Build Coastguard Worker
1539*da0073e9SAndroid Build Coastguard Worker    def test_disable_enable_subclass(self):
1540*da0073e9SAndroid Build Coastguard Worker        called = False
1541*da0073e9SAndroid Build Coastguard Worker
1542*da0073e9SAndroid Build Coastguard Worker        class A(torch.Tensor):
1543*da0073e9SAndroid Build Coastguard Worker            pass
1544*da0073e9SAndroid Build Coastguard Worker
1545*da0073e9SAndroid Build Coastguard Worker        x = A(torch.randn(5))
1546*da0073e9SAndroid Build Coastguard Worker        with torch._C.DisableTorchFunctionSubclass():
1547*da0073e9SAndroid Build Coastguard Worker            g = torch._C._EnableTorchFunction()
1548*da0073e9SAndroid Build Coastguard Worker            try:
1549*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(torch.sum(x), A)
1550*da0073e9SAndroid Build Coastguard Worker            finally:
1551*da0073e9SAndroid Build Coastguard Worker                del g
1552*da0073e9SAndroid Build Coastguard Worker
1553*da0073e9SAndroid Build Coastguard Worker    def test_torch_function_all_disabled_api(self):
1554*da0073e9SAndroid Build Coastguard Worker        from torch._C import _is_torch_function_all_disabled
1555*da0073e9SAndroid Build Coastguard Worker
1556*da0073e9SAndroid Build Coastguard Worker        state = _is_torch_function_all_disabled()
1557*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(state)
1558*da0073e9SAndroid Build Coastguard Worker
1559*da0073e9SAndroid Build Coastguard Worker        with torch._C.DisableTorchFunction():
1560*da0073e9SAndroid Build Coastguard Worker            state = _is_torch_function_all_disabled()
1561*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(state)
1562*da0073e9SAndroid Build Coastguard Worker
1563*da0073e9SAndroid Build Coastguard Worker        state = _is_torch_function_all_disabled()
1564*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(state)
1565*da0073e9SAndroid Build Coastguard Worker
1566*da0073e9SAndroid Build Coastguard Worker        with torch._C.DisableTorchFunctionSubclass():
1567*da0073e9SAndroid Build Coastguard Worker            state = _is_torch_function_all_disabled()
1568*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(state)
1569*da0073e9SAndroid Build Coastguard Worker
1570*da0073e9SAndroid Build Coastguard Worker    def test_subclass_hash(self):
1571*da0073e9SAndroid Build Coastguard Worker        class DiagTensor(torch.Tensor):
1572*da0073e9SAndroid Build Coastguard Worker            def __init__(self, diag):
1573*da0073e9SAndroid Build Coastguard Worker                self._diag = diag
1574*da0073e9SAndroid Build Coastguard Worker
1575*da0073e9SAndroid Build Coastguard Worker            @classmethod
1576*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, types, args=(), kwargs=None):
1577*da0073e9SAndroid Build Coastguard Worker                kwargs = kwargs or {}
1578*da0073e9SAndroid Build Coastguard Worker
1579*da0073e9SAndroid Build Coastguard Worker                def get_full_matrices(t):
1580*da0073e9SAndroid Build Coastguard Worker                    if isinstance(t, DiagTensor):
1581*da0073e9SAndroid Build Coastguard Worker                        return torch.diag_embed(t._diag)
1582*da0073e9SAndroid Build Coastguard Worker                    else:
1583*da0073e9SAndroid Build Coastguard Worker                        return t
1584*da0073e9SAndroid Build Coastguard Worker
1585*da0073e9SAndroid Build Coastguard Worker                return func(*tree_map(get_full_matrices, args), **tree_map(get_full_matrices, kwargs))
1586*da0073e9SAndroid Build Coastguard Worker
1587*da0073e9SAndroid Build Coastguard Worker        d = torch.rand(2)
1588*da0073e9SAndroid Build Coastguard Worker        a = DiagTensor(d)
1589*da0073e9SAndroid Build Coastguard Worker
1590*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((a + 1), torch.diag_embed(d) + 1)
1591*da0073e9SAndroid Build Coastguard Worker
1592*da0073e9SAndroid Build Coastguard Worker        # If the hash function was returning the same value, this would
1593*da0073e9SAndroid Build Coastguard Worker        # fail inside `Tensor.__eq__`.
1594*da0073e9SAndroid Build Coastguard Worker        # If __hash__ was going through torch_function, the implementation above would
1595*da0073e9SAndroid Build Coastguard Worker        # be wrong as it would compute the hash on a temporary Tensor thus not ensuring
1596*da0073e9SAndroid Build Coastguard Worker        # the uniqueness of the hash that we rely on for Tensors.
1597*da0073e9SAndroid Build Coastguard Worker        s = set()
1598*da0073e9SAndroid Build Coastguard Worker        s.add(a)
1599*da0073e9SAndroid Build Coastguard Worker        s.add(DiagTensor(d))
1600*da0073e9SAndroid Build Coastguard Worker
1601*da0073e9SAndroid Build Coastguard Worker    def test_custom_device_type(self):
1602*da0073e9SAndroid Build Coastguard Worker        class CustomDeviceContext(TorchFunctionMode):
1603*da0073e9SAndroid Build Coastguard Worker
1604*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(self, func, types, args=(), kwargs=None):
1605*da0073e9SAndroid Build Coastguard Worker                kwargs = kwargs or {}
1606*da0073e9SAndroid Build Coastguard Worker                if func == torch.device:
1607*da0073e9SAndroid Build Coastguard Worker                    if args and isinstance(args[0], int):
1608*da0073e9SAndroid Build Coastguard Worker                        args = ("xla", args[0])
1609*da0073e9SAndroid Build Coastguard Worker                    elif isinstance(kwargs.get('device'), int):
1610*da0073e9SAndroid Build Coastguard Worker                        kwargs['device'] = f"xla:{kwargs.get('device')}"
1611*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
1612*da0073e9SAndroid Build Coastguard Worker
1613*da0073e9SAndroid Build Coastguard Worker        with CustomDeviceContext():
1614*da0073e9SAndroid Build Coastguard Worker            d_args = torch.device(0)
1615*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(d_args.type, "xla")
1616*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(d_args.index, 0)
1617*da0073e9SAndroid Build Coastguard Worker            d_kwargs = torch.device(device=0)
1618*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(d_kwargs.type, "xla")
1619*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(d_kwargs.index, 0)
1620*da0073e9SAndroid Build Coastguard Worker
1621*da0073e9SAndroid Build Coastguard Worker    def test_device_context_semantics(self):
1622*da0073e9SAndroid Build Coastguard Worker        from torch._C import _len_torch_function_stack
1623*da0073e9SAndroid Build Coastguard Worker        from torch.utils._device import DeviceContext
1624*da0073e9SAndroid Build Coastguard Worker        try:
1625*da0073e9SAndroid Build Coastguard Worker            torch.set_default_device("cuda")
1626*da0073e9SAndroid Build Coastguard Worker
1627*da0073e9SAndroid Build Coastguard Worker            def get_stack():
1628*da0073e9SAndroid Build Coastguard Worker                return [torch._C._get_function_stack_at(i) for i in range(_len_torch_function_stack())]
1629*da0073e9SAndroid Build Coastguard Worker
1630*da0073e9SAndroid Build Coastguard Worker            base_mode = BaseTorchFunctionMode()
1631*da0073e9SAndroid Build Coastguard Worker            with base_mode:
1632*da0073e9SAndroid Build Coastguard Worker                torch.set_default_device("cpu")
1633*da0073e9SAndroid Build Coastguard Worker                x = torch.ones(2, 2)
1634*da0073e9SAndroid Build Coastguard Worker                stack = get_stack()
1635*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(stack[0], DeviceContext)
1636*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(stack[0].device, torch.device("cpu"))
1637*da0073e9SAndroid Build Coastguard Worker
1638*da0073e9SAndroid Build Coastguard Worker            stack = get_stack()
1639*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(stack[0], DeviceContext)
1640*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(stack[0].device, torch.device("cpu"))
1641*da0073e9SAndroid Build Coastguard Worker        finally:
1642*da0073e9SAndroid Build Coastguard Worker            torch.set_default_device(None)
1643*da0073e9SAndroid Build Coastguard Worker
1644*da0073e9SAndroid Build Coastguard Worker
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker
1647*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Hand off to PyTorch's shared test runner when executed as a script.
    run_tests()
1650