1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: __torch_function__"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerimport numpy as np 5*da0073e9SAndroid Build Coastguard Workerimport inspect 6*da0073e9SAndroid Build Coastguard Workerimport functools 7*da0073e9SAndroid Build Coastguard Workerimport pprint 8*da0073e9SAndroid Build Coastguard Workerimport pickle 9*da0073e9SAndroid Build Coastguard Workerimport collections 10*da0073e9SAndroid Build Coastguard Workerimport unittest 11*da0073e9SAndroid Build Coastguard Workerimport contextlib 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO 14*da0073e9SAndroid Build Coastguard Workerfrom torch.overrides import ( 15*da0073e9SAndroid Build Coastguard Worker handle_torch_function, 16*da0073e9SAndroid Build Coastguard Worker has_torch_function, 17*da0073e9SAndroid Build Coastguard Worker get_ignored_functions, 18*da0073e9SAndroid Build Coastguard Worker get_overridable_functions, 19*da0073e9SAndroid Build Coastguard Worker get_testing_overrides, 20*da0073e9SAndroid Build Coastguard Worker resolve_name, 21*da0073e9SAndroid Build Coastguard Worker is_tensor_method_or_property, 22*da0073e9SAndroid Build Coastguard Worker TorchFunctionMode, 23*da0073e9SAndroid Build Coastguard Worker _get_current_function_mode, 24*da0073e9SAndroid Build Coastguard Worker _get_current_function_mode_stack, 25*da0073e9SAndroid Build Coastguard Worker BaseTorchFunctionMode 26*da0073e9SAndroid Build Coastguard Worker) 27*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._mode_utils import all_same_mode 28*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard WorkerTensor = 
torch.Tensor 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker# The functions below simulate the pure-python torch functions in the 33*da0073e9SAndroid Build Coastguard Worker# torch.functional namespace. We use examples local to this file rather 34*da0073e9SAndroid Build Coastguard Worker# than any of the real examples implemented in Python since in the 35*da0073e9SAndroid Build Coastguard Worker# future those examples might get reimplemented in C++ for speed. This 36*da0073e9SAndroid Build Coastguard Worker# fake torch function allows us to verify that the dispatch rules work 37*da0073e9SAndroid Build Coastguard Worker# the same for a torch function implemented in C++ or Python. 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Workerdef foo(a, b, c=None): 40*da0073e9SAndroid Build Coastguard Worker """A function multiple arguments and an optional argument""" 41*da0073e9SAndroid Build Coastguard Worker if has_torch_function((a, b, c)): 42*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(foo, (a, b, c), a, b, c=c) 43*da0073e9SAndroid Build Coastguard Worker if c: 44*da0073e9SAndroid Build Coastguard Worker return a + b + c 45*da0073e9SAndroid Build Coastguard Worker return a + b 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Workerdef bar(a): 48*da0073e9SAndroid Build Coastguard Worker """A function with one argument""" 49*da0073e9SAndroid Build Coastguard Worker if has_torch_function((a,)): 50*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(bar, (a,), a) 51*da0073e9SAndroid Build Coastguard Worker return a 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Workerdef baz(a, b): 54*da0073e9SAndroid Build Coastguard Worker """A function with multiple arguments""" 55*da0073e9SAndroid Build Coastguard Worker if has_torch_function((a, b)): 56*da0073e9SAndroid Build Coastguard Worker return 
handle_torch_function(baz, (a, b), a, b) 57*da0073e9SAndroid Build Coastguard Worker return a + b 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Workerdef quux(a): 60*da0073e9SAndroid Build Coastguard Worker """Used to test that errors raised in user implementations get propagated""" 61*da0073e9SAndroid Build Coastguard Worker if has_torch_function((a,)): 62*da0073e9SAndroid Build Coastguard Worker return handle_torch_function(quux, (a,), a) 63*da0073e9SAndroid Build Coastguard Worker return a 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker# HANDLED_FUNCTIONS_DIAGONAL is a dispatch table that 66*da0073e9SAndroid Build Coastguard Worker# DiagonalTensor.__torch_function__ uses to determine which override 67*da0073e9SAndroid Build Coastguard Worker# function to call for a given torch API function. The keys of the 68*da0073e9SAndroid Build Coastguard Worker# dictionary are function names in the torch API and the values are 69*da0073e9SAndroid Build Coastguard Worker# function implementations. Implementations are added to 70*da0073e9SAndroid Build Coastguard Worker# HANDLED_FUNCTION_DIAGONAL by decorating a python function with 71*da0073e9SAndroid Build Coastguard Worker# implements_diagonal. See the overrides immediately below the defintion 72*da0073e9SAndroid Build Coastguard Worker# of DiagonalTensor for usage examples. 73*da0073e9SAndroid Build Coastguard WorkerHANDLED_FUNCTIONS_DIAGONAL = {} 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Workerdef implements_diagonal(torch_function): 76*da0073e9SAndroid Build Coastguard Worker """Register a torch function override for DiagonalTensor. 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker This decorator takes a function in the torch API as a 79*da0073e9SAndroid Build Coastguard Worker parameter. 
Applying this decorator to a function adds that function 80*da0073e9SAndroid Build Coastguard Worker as the registered override for the torch function passed as a 81*da0073e9SAndroid Build Coastguard Worker parameter to the decorator. See DiagonalTensor.__torch_function__ 82*da0073e9SAndroid Build Coastguard Worker for the runtime dispatch implementation and the decorated functions 83*da0073e9SAndroid Build Coastguard Worker immediately below DiagonalTensor for usage examples. 84*da0073e9SAndroid Build Coastguard Worker """ 85*da0073e9SAndroid Build Coastguard Worker @functools.wraps(torch_function) 86*da0073e9SAndroid Build Coastguard Worker def decorator(func): 87*da0073e9SAndroid Build Coastguard Worker HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func 88*da0073e9SAndroid Build Coastguard Worker return func 89*da0073e9SAndroid Build Coastguard Worker return decorator 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Workerclass DiagonalTensor: 92*da0073e9SAndroid Build Coastguard Worker """A class with __torch_function__ and a specific diagonal representation 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker This class has limited utility and is mostly useful for verifying that the 95*da0073e9SAndroid Build Coastguard Worker dispatch mechanism works as expected. It is based on the `DiagonalArray 96*da0073e9SAndroid Build Coastguard Worker example`_ in the NumPy documentation. 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker Note that this class does *not* inherit from ``torch.tensor``, interaction 99*da0073e9SAndroid Build Coastguard Worker with the pytorch dispatch system happens via the ``__torch_function__`` 100*da0073e9SAndroid Build Coastguard Worker protocol. 
101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker ``DiagonalTensor`` represents a 2D tensor with *N* rows and columns that has 103*da0073e9SAndroid Build Coastguard Worker diagonal entries set to *value* and all other entries set to zero. The 104*da0073e9SAndroid Build Coastguard Worker main functionality of ``DiagonalTensor`` is to provide a more compact 105*da0073e9SAndroid Build Coastguard Worker string representation of a diagonal tensor than in the base tensor class: 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker >>> d = DiagonalTensor(5, 2) 108*da0073e9SAndroid Build Coastguard Worker >>> d 109*da0073e9SAndroid Build Coastguard Worker DiagonalTensor(N=5, value=2) 110*da0073e9SAndroid Build Coastguard Worker >>> d.tensor() 111*da0073e9SAndroid Build Coastguard Worker tensor([[2., 0., 0., 0., 0.], 112*da0073e9SAndroid Build Coastguard Worker [0., 2., 0., 0., 0.], 113*da0073e9SAndroid Build Coastguard Worker [0., 0., 2., 0., 0.], 114*da0073e9SAndroid Build Coastguard Worker [0., 0., 0., 2., 0.], 115*da0073e9SAndroid Build Coastguard Worker [0., 0., 0., 0., 2.]]) 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker Note that to simplify testing, matrix multiplication of ``DiagonalTensor`` 118*da0073e9SAndroid Build Coastguard Worker returns 0: 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker >>> torch.mm(d, d) 121*da0073e9SAndroid Build Coastguard Worker 0 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker .. 
_DiagonalArray example: 124*da0073e9SAndroid Build Coastguard Worker https://numpy.org/devdocs/user/basics.dispatch.html 125*da0073e9SAndroid Build Coastguard Worker """ 126*da0073e9SAndroid Build Coastguard Worker # This is defined as a class attribute so that SubDiagonalTensor 127*da0073e9SAndroid Build Coastguard Worker # below which subclasses DiagonalTensor can re-use DiagonalTensor's 128*da0073e9SAndroid Build Coastguard Worker # __torch_function__ implementation. 129*da0073e9SAndroid Build Coastguard Worker handled_functions = HANDLED_FUNCTIONS_DIAGONAL 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker def __init__(self, N, value): 132*da0073e9SAndroid Build Coastguard Worker self._N = N 133*da0073e9SAndroid Build Coastguard Worker self._i = value 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 136*da0073e9SAndroid Build Coastguard Worker return f"DiagonalTensor(N={self._N}, value={self._i})" 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker def __array__(self): 139*da0073e9SAndroid Build Coastguard Worker return self._i * np.eye(self._N) 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Worker def tensor(self): 142*da0073e9SAndroid Build Coastguard Worker return self._i * torch.eye(self._N) 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker @classmethod 145*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 146*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 147*da0073e9SAndroid Build Coastguard Worker kwargs = {} 148*da0073e9SAndroid Build Coastguard Worker if func not in cls.handled_functions: 149*da0073e9SAndroid Build Coastguard Worker return NotImplemented 150*da0073e9SAndroid Build Coastguard Worker return cls.handled_functions[func](*args, **kwargs) 151*da0073e9SAndroid Build 
Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker def __eq__(self, other): 153*da0073e9SAndroid Build Coastguard Worker return type(other) is type(self) and self._N == other._N and self._i == other._i 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker@implements_diagonal(torch.mean) 156*da0073e9SAndroid Build Coastguard Workerdef mean(mat): 157*da0073e9SAndroid Build Coastguard Worker return float(mat._i) / mat._N 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker@implements_diagonal(torch.mm) 160*da0073e9SAndroid Build Coastguard Workerdef diagonal_mm(mat1, mat2): 161*da0073e9SAndroid Build Coastguard Worker return 0 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker@implements_diagonal(torch.div) 164*da0073e9SAndroid Build Coastguard Workerdef diagonal_div(input, other, out=None): 165*da0073e9SAndroid Build Coastguard Worker return -1 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker@implements_diagonal(torch.add) 168*da0073e9SAndroid Build Coastguard Workerdef add(mat1, mat2): 169*da0073e9SAndroid Build Coastguard Worker raise ValueError 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker@implements_diagonal(foo) 172*da0073e9SAndroid Build Coastguard Workerdef diagonal_foo(a, b, c=None): 173*da0073e9SAndroid Build Coastguard Worker return -1 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker@implements_diagonal(bar) 176*da0073e9SAndroid Build Coastguard Workerdef diagonal_bar(a): 177*da0073e9SAndroid Build Coastguard Worker return -1 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker@implements_diagonal(quux) 180*da0073e9SAndroid Build Coastguard Workerdef diagonal_quux(a): 181*da0073e9SAndroid Build Coastguard Worker raise ValueError 182*da0073e9SAndroid Build Coastguard Worker 
# The dispatch table for SubTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_SUB = {}

def implements_sub(torch_function):
    "Register a torch function override for SubTensor"
    # The inner function only records ``func`` in the dispatch table and
    # returns it untouched; it is a registrar, not a wrapper around
    # ``torch_function``, so it is deliberately not decorated with
    # ``functools.wraps``.
    def decorator(func):
        HANDLED_FUNCTIONS_SUB[torch_function] = func
        return func
    return decorator

class SubTensor(torch.Tensor):
    """A subclass of torch.Tensor used for testing __torch_function__ dispatch

    This class has the property that matrix multiplication returns -1
    (the value returned by the ``torch.mm`` override registered below):

    >>> s = SubTensor([[1, 1], [1, 1]])
    >>> torch.mm(s, s)
    -1
    >>> t = torch.tensor([[1, 1], [1, 1]])
    >>> torch.mm(s, t)
    -1
    >>> torch.mm(t, s)
    -1
    >>> torch.mm(t, t)
    tensor([[2, 2],
            [2, 2]])

    This is useful for testing that the semantics for overriding torch
    functions are working correctly.
    """
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``func`` to the override registered in HANDLED_FUNCTIONS_SUB.

        Returns ``NotImplemented`` when no override is registered for
        ``func``; otherwise returns whatever the override returns when
        called with ``*args`` and ``**kwargs``.
        """
        if kwargs is None:
            kwargs = {}

        if func not in HANDLED_FUNCTIONS_SUB:
            return NotImplemented
        return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs)

class SubTensor2(torch.Tensor):
    pass

class SubSubTensor2(SubTensor2):
    pass

class SubTensor3(torch.Tensor):
    pass

@implements_sub(torch.mean)
def sub_mean(mat):
    """SubTensor override for ``torch.mean``; always returns the sentinel 0."""
    return 0

@implements_sub(torch.mm)
def sub_mm(mat1, mat2):
    """SubTensor override for ``torch.mm``; always returns the sentinel -1."""
    return -1

@implements_sub(bar)
def sub_bar(mat):
    """SubTensor override for ``bar``; always returns the sentinel 1."""
    return 1

@implements_sub(torch.div)
def sub_div(input, other, out=None):
    """SubTensor override for ``torch.div``; always returns ``NotImplemented``."""
    return NotImplemented

# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_SUB_DIAGONAL = {}

def implements_sub_diagonal(torch_function):
    "Register a torch function override for SubDiagonalTensor"
    # The inner function only records ``func`` in the dispatch table and
    # returns it untouched; it is a registrar, not a wrapper around
    # ``torch_function``, so it is deliberately not decorated with
    # ``functools.wraps``.
    def decorator(func):
        HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func
        return func
    return decorator

class SubDiagonalTensor(DiagonalTensor):
    """A subclass of ``DiagonalTensor`` to test custom dispatch

    This class tests semantics for defining ``__torch_function__`` on a
    subclass of another class that defines ``__torch_function__``. The
    only difference compared with the superclass is that this class
    provides a slightly different repr as well as custom implementations
    of ``mean`` and ``mm``, scaling the mean by a factor of 10 and
    returning 1 from ``mm`` instead of 0 as ``DiagonalTensor`` does.
    """
    # Replaces the dispatch table inherited from DiagonalTensor with this
    # class's own table.
    handled_functions = HANDLED_FUNCTIONS_SUB_DIAGONAL

    def __repr__(self):
        return f"SubDiagonalTensor(N={self._N}, value={self._i})"


@implements_sub_diagonal(torch.mean)
def sub_diagonal_mean(mat):
    """SubDiagonalTensor override for ``torch.mean``: ten times value / N."""
    return 10 * float(mat._i) / mat._N

@implements_sub_diagonal(bar)
def sub_diagonal_bar(mat):
    """SubDiagonalTensor override for ``bar``; always returns the sentinel 0."""
    return 0

@implements_sub_diagonal(torch.mm)
def sub_diagonal_mm(mat1, mat2):
    """SubDiagonalTensor override for ``torch.mm``; always returns the sentinel 1."""
    return 1

@implements_sub_diagonal(torch.div)
def sub_diagonal_div(input, other, out=None):
    """SubDiagonalTensor override for ``torch.div``; always returns ``NotImplemented``."""
    return NotImplemented

Worker@implements_sub_diagonal(foo) 292*da0073e9SAndroid Build Coastguard Workerdef sub_diagonal_foo(a, b, c=None): 293*da0073e9SAndroid Build Coastguard Worker return NotImplemented 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker# The dispatch table for SubDiagonalTensor's __torch_function__ implementation. 296*da0073e9SAndroid Build Coastguard WorkerHANDLED_FUNCTIONS_TENSOR_LIKE = {} 297*da0073e9SAndroid Build Coastguard Worker 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker# Note: _triggered wrapper 300*da0073e9SAndroid Build Coastguard Worker# Dict that wraps the implementations from get_testing_overrides into another 301*da0073e9SAndroid Build Coastguard Worker# function with a _triggered slot/flag. The triggered flag is set when the 302*da0073e9SAndroid Build Coastguard Worker# implementation is called. 303*da0073e9SAndroid Build Coastguard WorkerWRAPPED_TRIGGERED_IMPLS = {} 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Workerdef triggered_wrapper(f): 307*da0073e9SAndroid Build Coastguard Worker @functools.wraps(f) 308*da0073e9SAndroid Build Coastguard Worker def wrapped(*args, **kwargs): 309*da0073e9SAndroid Build Coastguard Worker wrapped._triggered = True 310*da0073e9SAndroid Build Coastguard Worker return f(*args, **kwargs) 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker wrapped._triggered = False 313*da0073e9SAndroid Build Coastguard Worker return wrapped 314*da0073e9SAndroid Build Coastguard Worker 315*da0073e9SAndroid Build Coastguard Workerdef implements_tensor_like(torch_function): 316*da0073e9SAndroid Build Coastguard Worker "Register a torch function override for TensorLike" 317*da0073e9SAndroid Build Coastguard Worker @functools.wraps(torch_function) 318*da0073e9SAndroid Build Coastguard Worker def decorator(func): 319*da0073e9SAndroid Build 
Coastguard Worker HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func 320*da0073e9SAndroid Build Coastguard Worker return func 321*da0073e9SAndroid Build Coastguard Worker return decorator 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Workerdef generate_tensor_like_torch_implementations(): 324*da0073e9SAndroid Build Coastguard Worker torch_vars = vars(torch) 325*da0073e9SAndroid Build Coastguard Worker untested_funcs = [] 326*da0073e9SAndroid Build Coastguard Worker testing_overrides = get_testing_overrides() 327*da0073e9SAndroid Build Coastguard Worker # test/test_cpp_api_parity.py monkeypatches torch.nn to have a new 328*da0073e9SAndroid Build Coastguard Worker # function sample_functional. Depending on what order you run pytest 329*da0073e9SAndroid Build Coastguard Worker # collection, this may trigger the error here. This is a hack to fix 330*da0073e9SAndroid Build Coastguard Worker # the problem. A more proper fix is to make the "not tested" check 331*da0073e9SAndroid Build Coastguard Worker # a test on its own, and to make sure the monkeypatch is only installed 332*da0073e9SAndroid Build Coastguard Worker # for the span of the relevant test (and deleted afterwards) 333*da0073e9SAndroid Build Coastguard Worker testing_ignore = {"sample_functional", "autocast"} 334*da0073e9SAndroid Build Coastguard Worker for namespace, funcs in get_overridable_functions().items(): 335*da0073e9SAndroid Build Coastguard Worker for func in funcs: 336*da0073e9SAndroid Build Coastguard Worker if func not in testing_overrides and func.__name__ not in testing_ignore: 337*da0073e9SAndroid Build Coastguard Worker untested_funcs.append(f"{namespace}.{func.__name__}") 338*da0073e9SAndroid Build Coastguard Worker msg = ( 339*da0073e9SAndroid Build Coastguard Worker "The following functions are not tested for __torch_function__ " 340*da0073e9SAndroid Build Coastguard Worker "support, please ensure there is an entry in the dict returned by " 
341*da0073e9SAndroid Build Coastguard Worker "torch.overrides.get_testing_overrides for this function or if a " 342*da0073e9SAndroid Build Coastguard Worker "__torch_function__ override does not make sense, add an entry to " 343*da0073e9SAndroid Build Coastguard Worker "the tuple returned by torch._overrides.get_ignored_functions.\n\n{}" 344*da0073e9SAndroid Build Coastguard Worker ) 345*da0073e9SAndroid Build Coastguard Worker assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) 346*da0073e9SAndroid Build Coastguard Worker for func, override in testing_overrides.items(): 347*da0073e9SAndroid Build Coastguard Worker # decorate the overrides with implements_tensor_like if it's not a 348*da0073e9SAndroid Build Coastguard Worker # torch.Tensor method 349*da0073e9SAndroid Build Coastguard Worker wrapped = triggered_wrapper(override) 350*da0073e9SAndroid Build Coastguard Worker # See note: "_triggered wrapper" 351*da0073e9SAndroid Build Coastguard Worker WRAPPED_TRIGGERED_IMPLS[func] = wrapped 352*da0073e9SAndroid Build Coastguard Worker if is_tensor_method_or_property(func): 353*da0073e9SAndroid Build Coastguard Worker implements_sub(func)(wrapped) 354*da0073e9SAndroid Build Coastguard Worker else: 355*da0073e9SAndroid Build Coastguard Worker implements_tensor_like(func)(wrapped) 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Workergenerate_tensor_like_torch_implementations() 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Workerclass TensorLike: 360*da0073e9SAndroid Build Coastguard Worker """A class that overrides the full torch API 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker This class is used to explicitly test that the full torch.tensor API 363*da0073e9SAndroid Build Coastguard Worker can be overriden with a class that defines __torch_function__. 
364*da0073e9SAndroid Build Coastguard Worker """ 365*da0073e9SAndroid Build Coastguard Worker @classmethod 366*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 367*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 368*da0073e9SAndroid Build Coastguard Worker kwargs = {} 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker if func not in HANDLED_FUNCTIONS_TENSOR_LIKE: 371*da0073e9SAndroid Build Coastguard Worker return NotImplemented 372*da0073e9SAndroid Build Coastguard Worker # In this case _torch_function_ should override TensorLike objects 373*da0073e9SAndroid Build Coastguard Worker return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs) 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Workerclass TestTorchFunctionOverride(TestCase): 376*da0073e9SAndroid Build Coastguard Worker @classmethod 377*da0073e9SAndroid Build Coastguard Worker def setUpClass(cls): 378*da0073e9SAndroid Build Coastguard Worker cls._stack = contextlib.ExitStack() 379*da0073e9SAndroid Build Coastguard Worker if TEST_WITH_TORCHDYNAMO: 380*da0073e9SAndroid Build Coastguard Worker # Add classes to the wrapped tensor subclasses 381*da0073e9SAndroid Build Coastguard Worker @contextlib.contextmanager 382*da0073e9SAndroid Build Coastguard Worker def setup_subclasses(): 383*da0073e9SAndroid Build Coastguard Worker old = set(torch._dynamo.config.traceable_tensor_subclasses) 384*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor) 385*da0073e9SAndroid Build Coastguard Worker try: 386*da0073e9SAndroid Build Coastguard Worker yield 387*da0073e9SAndroid Build Coastguard Worker finally: 388*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.traceable_tensor_subclasses.clear() 389*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.traceable_tensor_subclasses.update(old) 
390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker cls._stack.enter_context(setup_subclasses()) 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker @classmethod 394*da0073e9SAndroid Build Coastguard Worker def tearDownClass(cls): 395*da0073e9SAndroid Build Coastguard Worker cls._stack.close() 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker def test_mean_semantics(self): 398*da0073e9SAndroid Build Coastguard Worker """Test that a function with one argument can be overridden""" 399*da0073e9SAndroid Build Coastguard Worker t1 = DiagonalTensor(5, 2) 400*da0073e9SAndroid Build Coastguard Worker t2 = SubTensor([[1, 2], [1, 2]]) 401*da0073e9SAndroid Build Coastguard Worker t3 = SubDiagonalTensor(5, 2) 402*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mean(t1), 0.4) 403*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bar(t1), -1) 404*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mean(t2), 0) 405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bar(t2), 1) 406*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mean(t3), 4.0) 407*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bar(t3), 0) 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker def test_has_torch_function_non_sequence(self): 410*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "expected a sequence"): 411*da0073e9SAndroid Build Coastguard Worker has_torch_function(object()) 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker def test_mm_semantics(self): 414*da0073e9SAndroid Build Coastguard Worker """Test that a function with multiple arguments can be overridden""" 415*da0073e9SAndroid Build Coastguard Worker t1 = DiagonalTensor(5, 2) 416*da0073e9SAndroid Build Coastguard Worker t2 = torch.eye(5) * 2 
417*da0073e9SAndroid Build Coastguard Worker t3 = SubTensor([[1, 2], [1, 2]]) 418*da0073e9SAndroid Build Coastguard Worker t4 = SubDiagonalTensor(5, 2) 419*da0073e9SAndroid Build Coastguard Worker # only DiagonalTensor so should always get DiagonalTensor result 420*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t1, t1), 0) 421*da0073e9SAndroid Build Coastguard Worker # tensor and DiagonalTensor, always return DiagonalTensor result 422*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t1, t2), 0) 423*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t2, t1), 0) 424*da0073e9SAndroid Build Coastguard Worker # only SubTensor so should always get SubTensor result 425*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t3, t3), -1) 426*da0073e9SAndroid Build Coastguard Worker # tensor and SubTensor so should always get SubTensor result 427*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t3, t2), -1) 428*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t2, t3), -1) 429*da0073e9SAndroid Build Coastguard Worker # DiagonalTensor and SubTensor are unrelated classes so the result 430*da0073e9SAndroid Build Coastguard Worker # depends on which argument appears first 431*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t3, t1), -1) 432*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t1, t3), 0) 433*da0073e9SAndroid Build Coastguard Worker # SubDiagonalTensor should take precedence over DiagonalTensor 434*da0073e9SAndroid Build Coastguard Worker # but should behave otherwise the same as DiagonalTensor 435*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t4, t4), 1) 436*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t4, t1), 1) 437*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t1, t4), 1) 438*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t4, t2), 1) 
439*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t2, t4), 1) 440*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t3, t4), -1) 441*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(t4, t3), 1) 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker def test_precedence_semantics(self): 444*da0073e9SAndroid Build Coastguard Worker """Test semantics for __torch_function__ for functions that take 445*da0073e9SAndroid Build Coastguard Worker multiple arguments 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker For functions that take multiple arguments, the appropriate 448*da0073e9SAndroid Build Coastguard Worker __torch_function__ implementation to call is determined by 449*da0073e9SAndroid Build Coastguard Worker examining the types of the arguments. The precedence order is 450*da0073e9SAndroid Build Coastguard Worker left-to-right in the argument list, except subclasses are always 451*da0073e9SAndroid Build Coastguard Worker checked before superclasses. The first result of calling the 452*da0073e9SAndroid Build Coastguard Worker implementations in precedence order that is not NotImplemented 453*da0073e9SAndroid Build Coastguard Worker is returned to the user. If all implementations return 454*da0073e9SAndroid Build Coastguard Worker NotImplemented, a TypeError is raised. 455*da0073e9SAndroid Build Coastguard Worker 456*da0073e9SAndroid Build Coastguard Worker All cases are tested with functions implemented in C++ and 457*da0073e9SAndroid Build Coastguard Worker either foo or baz, which are python functions defined above that 458*da0073e9SAndroid Build Coastguard Worker are instrumented to obey the same dispatch rules as the 459*da0073e9SAndroid Build Coastguard Worker functions in torch.functional. 
460*da0073e9SAndroid Build Coastguard Worker """ 461*da0073e9SAndroid Build Coastguard Worker # DiagonalTensor has a valid override and SubDiagonal has an 462*da0073e9SAndroid Build Coastguard Worker # override that returns NotImplemented so we should call the 463*da0073e9SAndroid Build Coastguard Worker # DiagonalTensor implementation, returning -1 464*da0073e9SAndroid Build Coastguard Worker t1 = DiagonalTensor(5, 2) 465*da0073e9SAndroid Build Coastguard Worker t2 = SubDiagonalTensor(5, 2) 466*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.div(t1, t2), -1) 467*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.div(t2, t1), -1) 468*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(t1, t2), -1) 469*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(t2, t1), -1) 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker # SubTensor has an implementation that returns NotImplemented as 472*da0073e9SAndroid Build Coastguard Worker # well so it should behave exactly like SubDiagonalTensor in the 473*da0073e9SAndroid Build Coastguard Worker # test above 474*da0073e9SAndroid Build Coastguard Worker t3 = SubTensor([[1, 2], [1, 2]]) 475*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.div(t1, t3), -1) 476*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.div(t3, t1), -1) 477*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(t1, t3), -1) 478*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(t3, t1), -1) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker # div between SubTensor and SubDiagonalTensor should raise 481*da0073e9SAndroid Build Coastguard Worker # TypeError since both have an implementation that 482*da0073e9SAndroid Build Coastguard Worker # explicitly returns NotImplemented 483*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 484*da0073e9SAndroid Build Coastguard 
Worker torch.div(t2, t3) 485*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 486*da0073e9SAndroid Build Coastguard Worker torch.div(t3, t2) 487*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 488*da0073e9SAndroid Build Coastguard Worker foo(t2, t3) 489*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 490*da0073e9SAndroid Build Coastguard Worker foo(t3, t2) 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker # none of DiagonalTensor, SubdiagonalTensor, or SubTensor have a 493*da0073e9SAndroid Build Coastguard Worker # mul or a baz implementation so all ops should raise TypeError 494*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 495*da0073e9SAndroid Build Coastguard Worker torch.mul(t1, t1) 496*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 497*da0073e9SAndroid Build Coastguard Worker torch.mul(t1, t2) 498*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 499*da0073e9SAndroid Build Coastguard Worker torch.mul(t1, t3) 500*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 501*da0073e9SAndroid Build Coastguard Worker torch.mul(t2, t1) 502*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 503*da0073e9SAndroid Build Coastguard Worker torch.mul(t2, t2) 504*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 505*da0073e9SAndroid Build Coastguard Worker torch.mul(t2, t3) 506*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 507*da0073e9SAndroid Build Coastguard Worker torch.mul(t3, t1) 508*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 509*da0073e9SAndroid Build Coastguard Worker torch.mul(t3, t2) 510*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 511*da0073e9SAndroid Build Coastguard Worker torch.mul(t3, t3) 
512*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 513*da0073e9SAndroid Build Coastguard Worker baz(t1, t1) 514*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 515*da0073e9SAndroid Build Coastguard Worker baz(t1, t2) 516*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 517*da0073e9SAndroid Build Coastguard Worker baz(t1, t3) 518*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 519*da0073e9SAndroid Build Coastguard Worker baz(t2, t1) 520*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 521*da0073e9SAndroid Build Coastguard Worker baz(t2, t2) 522*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 523*da0073e9SAndroid Build Coastguard Worker baz(t2, t3) 524*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 525*da0073e9SAndroid Build Coastguard Worker baz(t3, t1) 526*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 527*da0073e9SAndroid Build Coastguard Worker baz(t3, t2) 528*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 529*da0073e9SAndroid Build Coastguard Worker baz(t3, t3) 530*da0073e9SAndroid Build Coastguard Worker 531*da0073e9SAndroid Build Coastguard Worker def test_user_implementation_raises(self): 532*da0073e9SAndroid Build Coastguard Worker """Test that errors raised in user implementations propagate correctly""" 533*da0073e9SAndroid Build Coastguard Worker t1 = DiagonalTensor(5, 2) 534*da0073e9SAndroid Build Coastguard Worker t2 = DiagonalTensor(5, 2) 535*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 536*da0073e9SAndroid Build Coastguard Worker torch.add(t1, t2) 537*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ValueError): 538*da0073e9SAndroid Build Coastguard Worker quux(t1) 539*da0073e9SAndroid Build Coastguard Worker 540*da0073e9SAndroid Build Coastguard Worker def 
test_tensor_subclass_propagation(self): 541*da0073e9SAndroid Build Coastguard Worker """this test exercises the functionality described in 542*da0073e9SAndroid Build Coastguard Worker docs/source/notes/extending.rst#subclassing-torchtensor""" 543*da0073e9SAndroid Build Coastguard Worker t1 = torch.tensor([5]) 544*da0073e9SAndroid Build Coastguard Worker t2 = torch.tensor([6]) 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker s1 = SubTensor2([5]) 547*da0073e9SAndroid Build Coastguard Worker s2 = SubTensor2([6]) 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker ss1 = SubSubTensor2([5]) 550*da0073e9SAndroid Build Coastguard Worker ss2 = SubSubTensor2([6]) 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker sn1 = SubTensor3([5]) 553*da0073e9SAndroid Build Coastguard Worker sn2 = SubTensor3([6]) 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker # Check that leaf subclass is kept regardless of order 556*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s1 + t2, SubTensor2)) 557*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(t1 + s2, SubTensor2)) 558*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s1 + s2, SubTensor2)) 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker # Check indexing subclass is kept 561*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s1[0], SubTensor2)) 562*da0073e9SAndroid Build Coastguard Worker 563*da0073e9SAndroid Build Coastguard Worker # Check case for subclass of subclass. 
564*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2)) 565*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(ss1 + s2, SubSubTensor2)) 566*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(s1 + ss2, SubSubTensor2)) 567*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2)) 568*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(ss1 + t2, SubSubTensor2)) 569*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(t1 + ss2, SubSubTensor2)) 570*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(ss1[0], SubSubTensor2)) 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker # Make sure unrelated class trees are not merged. 573*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 574*da0073e9SAndroid Build Coastguard Worker s1 + sn2 575*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 576*da0073e9SAndroid Build Coastguard Worker sn1 + s2 577*da0073e9SAndroid Build Coastguard Worker 578*da0073e9SAndroid Build Coastguard Worker def test_base(self): 579*da0073e9SAndroid Build Coastguard Worker # https://github.com/szagoruyko/pytorchviz/issues/65 580*da0073e9SAndroid Build Coastguard Worker class DummyTensor(torch.Tensor): 581*da0073e9SAndroid Build Coastguard Worker pass 582*da0073e9SAndroid Build Coastguard Worker 583*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1) 584*da0073e9SAndroid Build Coastguard Worker c = DummyTensor(a) 585*da0073e9SAndroid Build Coastguard Worker self.assertTrue(c._is_view()) 586*da0073e9SAndroid Build Coastguard Worker self.assertTrue(c._base is a) 587*da0073e9SAndroid Build Coastguard Worker 588*da0073e9SAndroid Build Coastguard Worker def test_grad(self): 589*da0073e9SAndroid Build Coastguard Worker # Previously, Tensor-like objects that did not subclass from Tensor 590*da0073e9SAndroid 
Build Coastguard Worker # did not get wrapped into unary tuples before being passed into 591*da0073e9SAndroid Build Coastguard Worker # handle_torch_function, in contradiction with how Tensor-likes 592*da0073e9SAndroid Build Coastguard Worker # were handled 593*da0073e9SAndroid Build Coastguard Worker # 594*da0073e9SAndroid Build Coastguard Worker # NB: this asserts that the arguments get normalized into a tuple 595*da0073e9SAndroid Build Coastguard Worker # before entering the torch function handler; it could go the 596*da0073e9SAndroid Build Coastguard Worker # other way but beware https://github.com/pytorch/pytorch/issues/76037 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build Coastguard Worker class Dummy: 599*da0073e9SAndroid Build Coastguard Worker @classmethod 600*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 601*da0073e9SAndroid Build Coastguard Worker inputs, outputs = args 602*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inputs, (x,)) 603*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs, (x,)) 604*da0073e9SAndroid Build Coastguard Worker return -1 605*da0073e9SAndroid Build Coastguard Worker 606*da0073e9SAndroid Build Coastguard Worker x = Dummy() 607*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.autograd.grad(x, x), -1) 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker def test_pow_rpow(self): 610*da0073e9SAndroid Build Coastguard Worker class NothingImplemented(torch.Tensor): 611*da0073e9SAndroid Build Coastguard Worker @classmethod 612*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 613*da0073e9SAndroid Build Coastguard Worker return NotImplemented 614*da0073e9SAndroid Build Coastguard Worker 615*da0073e9SAndroid Build Coastguard Worker class RPowOnly(torch.Tensor): 616*da0073e9SAndroid Build Coastguard Worker @classmethod 
617*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 618*da0073e9SAndroid Build Coastguard Worker if func is torch.Tensor.__rpow__: 619*da0073e9SAndroid Build Coastguard Worker return -1 620*da0073e9SAndroid Build Coastguard Worker return NotImplemented 621*da0073e9SAndroid Build Coastguard Worker 622*da0073e9SAndroid Build Coastguard Worker self.assertEqual(NothingImplemented() ** RPowOnly(), -1) 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Workerdef generate_tensor_like_override_tests(cls): 626*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.generated.annotated_fn_args import annotated_args 627*da0073e9SAndroid Build Coastguard Worker 628*da0073e9SAndroid Build Coastguard Worker def test_generator(func, override): 629*da0073e9SAndroid Build Coastguard Worker # If func corresponds to a torch.Tensor method or property. 630*da0073e9SAndroid Build Coastguard Worker if is_tensor_method_or_property(func): 631*da0073e9SAndroid Build Coastguard Worker # Generate an instance by using SubTensor, 632*da0073e9SAndroid Build Coastguard Worker def instance_gen(): 633*da0073e9SAndroid Build Coastguard Worker return SubTensor([5]) 634*da0073e9SAndroid Build Coastguard Worker else: 635*da0073e9SAndroid Build Coastguard Worker # Otherwise, TensorLike. 636*da0073e9SAndroid Build Coastguard Worker def instance_gen(): 637*da0073e9SAndroid Build Coastguard Worker return TensorLike() 638*da0073e9SAndroid Build Coastguard Worker 639*da0073e9SAndroid Build Coastguard Worker # FIXME The following code does not support kwonly args without defaults. 640*da0073e9SAndroid Build Coastguard Worker # The fix is easy, as one just needs to save these args when generating the variable 641*da0073e9SAndroid Build Coastguard Worker # annotated_args. 
The problem is that, if one does so, one finds a number 642*da0073e9SAndroid Build Coastguard Worker # of functions that have problematic signatures in native_functions.yaml. 643*da0073e9SAndroid Build Coastguard Worker # Fixing these would be BC breaking, so hence this terrible hack 644*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/67008 645*da0073e9SAndroid Build Coastguard Worker kwargs = {} 646*da0073e9SAndroid Build Coastguard Worker if hasattr(func, "__name__") and "linalg_solve_triangular" in func.__name__: 647*da0073e9SAndroid Build Coastguard Worker kwargs = {"upper": True} 648*da0073e9SAndroid Build Coastguard Worker 649*da0073e9SAndroid Build Coastguard Worker func_args = [] 650*da0073e9SAndroid Build Coastguard Worker is_method = is_tensor_method_or_property(func) 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Worker def _simple_type_parser(func, arg_name, arg_type): 653*da0073e9SAndroid Build Coastguard Worker # Guess valid input to aten function based on type of argument 654*da0073e9SAndroid Build Coastguard Worker if arg_type == "Tensor": 655*da0073e9SAndroid Build Coastguard Worker return instance_gen() 656*da0073e9SAndroid Build Coastguard Worker elif arg_type == "TensorList" or arg_type == "ITensorListRef": 657*da0073e9SAndroid Build Coastguard Worker return [instance_gen(), instance_gen()] 658*da0073e9SAndroid Build Coastguard Worker elif arg_type == "c10::List<::std::optional<Tensor>>": 659*da0073e9SAndroid Build Coastguard Worker return [instance_gen(), instance_gen()] 660*da0073e9SAndroid Build Coastguard Worker elif arg_type == "IntArrayRef" or arg_type == "SymIntArrayRef": 661*da0073e9SAndroid Build Coastguard Worker size = arg.get("size", 2) 662*da0073e9SAndroid Build Coastguard Worker if size == 1: 663*da0073e9SAndroid Build Coastguard Worker return 1 664*da0073e9SAndroid Build Coastguard Worker else: 665*da0073e9SAndroid Build Coastguard Worker return [1] * size 
666*da0073e9SAndroid Build Coastguard Worker elif arg_type == "Scalar": 667*da0073e9SAndroid Build Coastguard Worker return 3.5 668*da0073e9SAndroid Build Coastguard Worker elif arg_type == "bool": 669*da0073e9SAndroid Build Coastguard Worker return False 670*da0073e9SAndroid Build Coastguard Worker elif arg_type == "Dimname": 671*da0073e9SAndroid Build Coastguard Worker return "" 672*da0073e9SAndroid Build Coastguard Worker elif arg_type == "DimnameList": 673*da0073e9SAndroid Build Coastguard Worker return [""] 674*da0073e9SAndroid Build Coastguard Worker elif arg_type.startswith("int"): 675*da0073e9SAndroid Build Coastguard Worker return 0 676*da0073e9SAndroid Build Coastguard Worker elif arg_type in {"Stream"}: 677*da0073e9SAndroid Build Coastguard Worker return torch.Stream() 678*da0073e9SAndroid Build Coastguard Worker elif arg_type.startswith("float") or arg_type == "double": 679*da0073e9SAndroid Build Coastguard Worker return 1.0 680*da0073e9SAndroid Build Coastguard Worker elif arg_type in {"Generator", "MemoryFormat", "TensorOptions"}: 681*da0073e9SAndroid Build Coastguard Worker return None 682*da0073e9SAndroid Build Coastguard Worker elif arg_type == "ScalarType": 683*da0073e9SAndroid Build Coastguard Worker return torch.float32 684*da0073e9SAndroid Build Coastguard Worker elif arg_type == "c10::string_view": 685*da0073e9SAndroid Build Coastguard Worker return "" 686*da0073e9SAndroid Build Coastguard Worker elif arg_type == "SymInt": 687*da0073e9SAndroid Build Coastguard Worker # TODO: generate actual SymbolicInt 688*da0073e9SAndroid Build Coastguard Worker return 1 689*da0073e9SAndroid Build Coastguard Worker else: 690*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 691*da0073e9SAndroid Build Coastguard Worker f"Unsupported argument type {arg_type} for {arg_name} of function {func}" 692*da0073e9SAndroid Build Coastguard Worker ) 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker if func in 
annotated_args: 695*da0073e9SAndroid Build Coastguard Worker for arg in annotated_args[func]: 696*da0073e9SAndroid Build Coastguard Worker # Guess valid input to aten function based on type of argument 697*da0073e9SAndroid Build Coastguard Worker t = arg["simple_type"] 698*da0073e9SAndroid Build Coastguard Worker if t.endswith("?"): 699*da0073e9SAndroid Build Coastguard Worker t = t[:-1] 700*da0073e9SAndroid Build Coastguard Worker if t == "Tensor" and is_method and arg["name"] == "self": 701*da0073e9SAndroid Build Coastguard Worker # See "Note: properties and __get__" 702*da0073e9SAndroid Build Coastguard Worker func = func.__get__(instance_gen()) 703*da0073e9SAndroid Build Coastguard Worker continue 704*da0073e9SAndroid Build Coastguard Worker arg_to_add = _simple_type_parser(func, arg["name"], t) 705*da0073e9SAndroid Build Coastguard Worker if "is_kwarg_only" in arg and arg["is_kwarg_only"] == str(True): 706*da0073e9SAndroid Build Coastguard Worker kwargs[arg["name"]] = arg_to_add 707*da0073e9SAndroid Build Coastguard Worker else: 708*da0073e9SAndroid Build Coastguard Worker func_args.append(arg_to_add) 709*da0073e9SAndroid Build Coastguard Worker else: 710*da0073e9SAndroid Build Coastguard Worker args = inspect.getfullargspec(override) 711*da0073e9SAndroid Build Coastguard Worker try: 712*da0073e9SAndroid Build Coastguard Worker func_args = inspect.getfullargspec(func) 713*da0073e9SAndroid Build Coastguard Worker # Remove annotations from argspec 714*da0073e9SAndroid Build Coastguard Worker func_args = type(func_args)(**{**func_args, 'annotations': None}) 715*da0073e9SAndroid Build Coastguard Worker if func_args != args: 716*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"Override for {func} doesn't match its argspec.\n" 717*da0073e9SAndroid Build Coastguard Worker + f"Original: {inspect.signature(func)}\n" 718*da0073e9SAndroid Build Coastguard Worker + f"Override: {inspect.signature(override)}") 719*da0073e9SAndroid Build Coastguard Worker except 
TypeError: 720*da0073e9SAndroid Build Coastguard Worker pass 721*da0073e9SAndroid Build Coastguard Worker nargs = len(args.args) 722*da0073e9SAndroid Build Coastguard Worker if args.defaults is not None: 723*da0073e9SAndroid Build Coastguard Worker nargs -= len(args.defaults) 724*da0073e9SAndroid Build Coastguard Worker func_args = [instance_gen() for _ in range(nargs)] 725*da0073e9SAndroid Build Coastguard Worker if args.varargs is not None: 726*da0073e9SAndroid Build Coastguard Worker func_args += [instance_gen(), instance_gen()] 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard Worker def test(self): 729*da0073e9SAndroid Build Coastguard Worker ret = func(*func_args, **kwargs) 730*da0073e9SAndroid Build Coastguard Worker # ret is None for certain protocols, e.g., `__weakref__` and `__setitem__` 731*da0073e9SAndroid Build Coastguard Worker # This is currently the best check but doesn't work for, for example, 732*da0073e9SAndroid Build Coastguard Worker # Tensor.__add__ because it redirects to Tensor.add. 
733*da0073e9SAndroid Build Coastguard Worker # See note "_triggered wrapper" 734*da0073e9SAndroid Build Coastguard Worker if not is_method or ret is None: 735*da0073e9SAndroid Build Coastguard Worker self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered) 736*da0073e9SAndroid Build Coastguard Worker return 737*da0073e9SAndroid Build Coastguard Worker 738*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ret, -1) 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build Coastguard Worker return test 741*da0073e9SAndroid Build Coastguard Worker 742*da0073e9SAndroid Build Coastguard Worker for func, override in get_testing_overrides().items(): 743*da0073e9SAndroid Build Coastguard Worker test_method = test_generator(func, override) 744*da0073e9SAndroid Build Coastguard Worker if func.__name__ == "__get__": 745*da0073e9SAndroid Build Coastguard Worker # Note: properties and __get__ 746*da0073e9SAndroid Build Coastguard Worker # __get__ is part of the descriptor protocol. 747*da0073e9SAndroid Build Coastguard Worker # https://docs.python.org/3/howto/descriptor.html 748*da0073e9SAndroid Build Coastguard Worker # This is used for properties of the form 749*da0073e9SAndroid Build Coastguard Worker # torch.Tensor.<property>, with the method __get__ 750*da0073e9SAndroid Build Coastguard Worker # In this case we get the property name in two ways: 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Worker # This case for properties defined in C. 753*da0073e9SAndroid Build Coastguard Worker module = getattr( 754*da0073e9SAndroid Build Coastguard Worker func.__self__, 755*da0073e9SAndroid Build Coastguard Worker "__qualname__", 756*da0073e9SAndroid Build Coastguard Worker None 757*da0073e9SAndroid Build Coastguard Worker ) 758*da0073e9SAndroid Build Coastguard Worker 759*da0073e9SAndroid Build Coastguard Worker # This one for properties defined in Python. 
760*da0073e9SAndroid Build Coastguard Worker if module is None: 761*da0073e9SAndroid Build Coastguard Worker module = "Tensor." + func.__self__.fget.__name__ 762*da0073e9SAndroid Build Coastguard Worker 763*da0073e9SAndroid Build Coastguard Worker # Unfortunately I couldn't find a way to unify these two cases 764*da0073e9SAndroid Build Coastguard Worker # and there is no way for general descriptors. 765*da0073e9SAndroid Build Coastguard Worker elif is_tensor_method_or_property(func): 766*da0073e9SAndroid Build Coastguard Worker module = "Tensor" 767*da0073e9SAndroid Build Coastguard Worker else: 768*da0073e9SAndroid Build Coastguard Worker module = func.__module__ 769*da0073e9SAndroid Build Coastguard Worker if module: 770*da0073e9SAndroid Build Coastguard Worker name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__) 771*da0073e9SAndroid Build Coastguard Worker else: 772*da0073e9SAndroid Build Coastguard Worker name = f'test_{func.__name__}' 773*da0073e9SAndroid Build Coastguard Worker test_method.__name__ = name 774*da0073e9SAndroid Build Coastguard Worker setattr(cls, name, test_method) 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Workergenerate_tensor_like_override_tests(TestTorchFunctionOverride) 777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Workerclass Wrapper: 779*da0073e9SAndroid Build Coastguard Worker "Basic data container that knows how to unwrap itself" 780*da0073e9SAndroid Build Coastguard Worker def __init__(self, data): 781*da0073e9SAndroid Build Coastguard Worker self.__dict__["_data"] = data 782*da0073e9SAndroid Build Coastguard Worker self.__dict__["used_attrs"] = set() 783*da0073e9SAndroid Build Coastguard Worker self.__dict__["used_calls"] = set() 784*da0073e9SAndroid Build Coastguard Worker 785*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, name): 786*da0073e9SAndroid Build Coastguard Worker if name in self.__dict__: 787*da0073e9SAndroid Build 
Coastguard Worker return self.__dict__[name] 788*da0073e9SAndroid Build Coastguard Worker self.used_attrs.add(name) 789*da0073e9SAndroid Build Coastguard Worker 790*da0073e9SAndroid Build Coastguard Worker val = getattr(self._data, name) 791*da0073e9SAndroid Build Coastguard Worker 792*da0073e9SAndroid Build Coastguard Worker # If it's a method 793*da0073e9SAndroid Build Coastguard Worker if not isinstance(val, torch.device) and callable(val): 794*da0073e9SAndroid Build Coastguard Worker c = getattr(type(self._data), name) 795*da0073e9SAndroid Build Coastguard Worker # Don't append self to args if classmethod/staticmethod 796*da0073e9SAndroid Build Coastguard Worker if c is val: 797*da0073e9SAndroid Build Coastguard Worker return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw)) 798*da0073e9SAndroid Build Coastguard Worker # Otherwise append self to args 799*da0073e9SAndroid Build Coastguard Worker return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw)) 800*da0073e9SAndroid Build Coastguard Worker 801*da0073e9SAndroid Build Coastguard Worker return wrap(val) 802*da0073e9SAndroid Build Coastguard Worker 803*da0073e9SAndroid Build Coastguard Worker def __setattr__(self, name, value): 804*da0073e9SAndroid Build Coastguard Worker if name in self.__dict__: 805*da0073e9SAndroid Build Coastguard Worker self.__dict__[name] = value 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Worker self.used_attrs.add(name) 808*da0073e9SAndroid Build Coastguard Worker setattr(self._data, name, unwrap(value)) 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker def __setitem__(self, key, value): 811*da0073e9SAndroid Build Coastguard Worker self._data[unwrap(key)] = unwrap(value) 812*da0073e9SAndroid Build Coastguard Worker 813*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, key): 814*da0073e9SAndroid Build Coastguard Worker return 
wrap(self._data[unwrap(key)]) 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Worker @classmethod 817*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 818*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 819*da0073e9SAndroid Build Coastguard Worker kwargs = {} 820*da0073e9SAndroid Build Coastguard Worker # Find an instance of this class in the arguments 821*da0073e9SAndroid Build Coastguard Worker args_of_this_cls = [] 822*da0073e9SAndroid Build Coastguard Worker for a in args: 823*da0073e9SAndroid Build Coastguard Worker if isinstance(a, cls): 824*da0073e9SAndroid Build Coastguard Worker args_of_this_cls.append(a) 825*da0073e9SAndroid Build Coastguard Worker elif isinstance(a, collections.abc.Sequence): 826*da0073e9SAndroid Build Coastguard Worker args_of_this_cls.extend(el for el in a if isinstance(el, cls)) 827*da0073e9SAndroid Build Coastguard Worker assert len(args_of_this_cls) > 0 828*da0073e9SAndroid Build Coastguard Worker for a in args_of_this_cls: 829*da0073e9SAndroid Build Coastguard Worker a.used_calls.add(func) 830*da0073e9SAndroid Build Coastguard Worker args = unwrap(tuple(args)) 831*da0073e9SAndroid Build Coastguard Worker kwargs = {k: unwrap(v) for k, v in kwargs.items()} 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker return wrap(func(*args, **kwargs)) 834*da0073e9SAndroid Build Coastguard Worker 835*da0073e9SAndroid Build Coastguard Worker def __add__(self, other): 836*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.add, (Wrapper,), (self, other)) 837*da0073e9SAndroid Build Coastguard Worker 838*da0073e9SAndroid Build Coastguard Worker def __mul__(self, other): 839*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.mul, (Wrapper,), (self, other)) 840*da0073e9SAndroid Build Coastguard Worker 841*da0073e9SAndroid Build Coastguard Worker def 
__sub__(self, other): 842*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.sub, (Wrapper,), (self, other)) 843*da0073e9SAndroid Build Coastguard Worker 844*da0073e9SAndroid Build Coastguard Worker def __truediv__(self, other): 845*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other)) 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker def __floordiv__(self, other): 848*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other)) 849*da0073e9SAndroid Build Coastguard Worker 850*da0073e9SAndroid Build Coastguard Worker def __ge__(self, other): 851*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.ge, (Wrapper,), (self, other)) 852*da0073e9SAndroid Build Coastguard Worker 853*da0073e9SAndroid Build Coastguard Worker def __gt__(self, other): 854*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.gt, (Wrapper,), (self, other)) 855*da0073e9SAndroid Build Coastguard Worker 856*da0073e9SAndroid Build Coastguard Worker def __lt__(self, other): 857*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.lt, (Wrapper,), (self, other)) 858*da0073e9SAndroid Build Coastguard Worker 859*da0073e9SAndroid Build Coastguard Worker def __le__(self, other): 860*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.le, (Wrapper,), (self, other)) 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker def __eq__(self, other): 863*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.eq, (Wrapper,), (self, other)) 864*da0073e9SAndroid Build Coastguard Worker 865*da0073e9SAndroid Build Coastguard Worker def __ne__(self, other): 866*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.ne, (Wrapper,), (self, other)) 
867*da0073e9SAndroid Build Coastguard Worker 868*da0073e9SAndroid Build Coastguard Worker def __bool__(self): 869*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,)) 870*da0073e9SAndroid Build Coastguard Worker 871*da0073e9SAndroid Build Coastguard Worker def __int__(self): 872*da0073e9SAndroid Build Coastguard Worker return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,)) 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Worker def __len__(self): 875*da0073e9SAndroid Build Coastguard Worker return len(self._data) 876*da0073e9SAndroid Build Coastguard Worker 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker# unwrap inputs if necessary 879*da0073e9SAndroid Build Coastguard Workerdef unwrap(v): 880*da0073e9SAndroid Build Coastguard Worker if type(v) in {tuple, list}: 881*da0073e9SAndroid Build Coastguard Worker return type(v)(unwrap(vi) for vi in v) 882*da0073e9SAndroid Build Coastguard Worker 883*da0073e9SAndroid Build Coastguard Worker return v._data if isinstance(v, Wrapper) else v 884*da0073e9SAndroid Build Coastguard Worker 885*da0073e9SAndroid Build Coastguard Worker# wrap inputs if necessary 886*da0073e9SAndroid Build Coastguard Workerdef wrap(v): 887*da0073e9SAndroid Build Coastguard Worker if type(v) in {tuple, list}: 888*da0073e9SAndroid Build Coastguard Worker return type(v)(wrap(vi) for vi in v) 889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker return Wrapper(v) if isinstance(v, torch.Tensor) else v 891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Workerclass TestEinsumOverride(TestCase): 893*da0073e9SAndroid Build Coastguard Worker "Regression test for gh-38479" 894*da0073e9SAndroid Build Coastguard Worker def test_wrapper(self): 895*da0073e9SAndroid Build Coastguard Worker x = Wrapper(torch.randn(5)) 896*da0073e9SAndroid Build 
Coastguard Worker y = Wrapper(torch.randn(4)) 897*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.einsum('i,j->ij', x, y)._data, 898*da0073e9SAndroid Build Coastguard Worker torch.ger(x, y)._data) 899*da0073e9SAndroid Build Coastguard Worker 900*da0073e9SAndroid Build Coastguard Worker # in the old einsum interface, `operands` is a list 901*da0073e9SAndroid Build Coastguard Worker a = Wrapper(torch.randn(2, 3)) 902*da0073e9SAndroid Build Coastguard Worker b = Wrapper(torch.randn(5, 3, 7)) 903*da0073e9SAndroid Build Coastguard Worker c = Wrapper(torch.randn(2, 7)) 904*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.einsum('ik,jkl,il->ij', [a, b, c])._data, 905*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.bilinear(a, c, b)._data) 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Workerclass TestGradCheckOverride(TestCase): 908*da0073e9SAndroid Build Coastguard Worker "Test that wrappers work with gradcheck." 909*da0073e9SAndroid Build Coastguard Worker def test_gradcheck(self): 910*da0073e9SAndroid Build Coastguard Worker from torch.testing._internal.common_utils import gradcheck, gradgradcheck 911*da0073e9SAndroid Build Coastguard Worker 912*da0073e9SAndroid Build Coastguard Worker def run_test(fast_mode): 913*da0073e9SAndroid Build Coastguard Worker a = wrap(torch.tensor(5.0, dtype=torch.double)) 914*da0073e9SAndroid Build Coastguard Worker b = wrap(torch.tensor(6.0, dtype=torch.double)) 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Worker a.requires_grad = True 917*da0073e9SAndroid Build Coastguard Worker b.requires_grad = True 918*da0073e9SAndroid Build Coastguard Worker 919*da0073e9SAndroid Build Coastguard Worker gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode) 920*da0073e9SAndroid Build Coastguard Worker gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, 
fast_mode=fast_mode) 921*da0073e9SAndroid Build Coastguard Worker 922*da0073e9SAndroid Build Coastguard Worker total_used_attrs = a.used_attrs.union(b.used_attrs) 923*da0073e9SAndroid Build Coastguard Worker total_used_calls = a.used_calls.union(b.used_calls) 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker # These attributes (and the functions below) may change 926*da0073e9SAndroid Build Coastguard Worker # if the gradcheck implementation changes. It's best to 927*da0073e9SAndroid Build Coastguard Worker # aim for attributes that may be commonly present on other 928*da0073e9SAndroid Build Coastguard Worker # Tensor-likes. 929*da0073e9SAndroid Build Coastguard Worker expected_used_attrs = { 930*da0073e9SAndroid Build Coastguard Worker 'data', 931*da0073e9SAndroid Build Coastguard Worker 'dtype', 932*da0073e9SAndroid Build Coastguard Worker 'is_floating_point', 933*da0073e9SAndroid Build Coastguard Worker 'is_sparse', 934*da0073e9SAndroid Build Coastguard Worker 'layout', 935*da0073e9SAndroid Build Coastguard Worker 'new_zeros', 936*da0073e9SAndroid Build Coastguard Worker 'numel', 937*da0073e9SAndroid Build Coastguard Worker 'requires_grad', 938*da0073e9SAndroid Build Coastguard Worker 'requires_grad_', 939*da0073e9SAndroid Build Coastguard Worker 'size', 940*da0073e9SAndroid Build Coastguard Worker 'stride', 941*da0073e9SAndroid Build Coastguard Worker } 942*da0073e9SAndroid Build Coastguard Worker if fast_mode: 943*da0073e9SAndroid Build Coastguard Worker expected_used_attrs.add('is_complex') 944*da0073e9SAndroid Build Coastguard Worker expected_used_attrs.add('device') 945*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_used_attrs, total_used_attrs) 946*da0073e9SAndroid Build Coastguard Worker 947*da0073e9SAndroid Build Coastguard Worker expected_used_calls = { 948*da0073e9SAndroid Build Coastguard Worker torch.Tensor.new_zeros, 949*da0073e9SAndroid Build Coastguard Worker torch.Tensor.size, 
950*da0073e9SAndroid Build Coastguard Worker torch.Tensor.is_floating_point, 951*da0073e9SAndroid Build Coastguard Worker torch.Tensor.numel, 952*da0073e9SAndroid Build Coastguard Worker torch.Tensor.stride, 953*da0073e9SAndroid Build Coastguard Worker torch.Tensor.requires_grad_, 954*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad, 955*da0073e9SAndroid Build Coastguard Worker torch.add, 956*da0073e9SAndroid Build Coastguard Worker } 957*da0073e9SAndroid Build Coastguard Worker if fast_mode: 958*da0073e9SAndroid Build Coastguard Worker expected_used_calls.add(torch.Tensor.is_complex) 959*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_used_calls, total_used_calls) 960*da0073e9SAndroid Build Coastguard Worker run_test(fast_mode=True) 961*da0073e9SAndroid Build Coastguard Worker run_test(fast_mode=False) 962*da0073e9SAndroid Build Coastguard Worker 963*da0073e9SAndroid Build Coastguard Workerclass TestNamedTuple(TestCase): 964*da0073e9SAndroid Build Coastguard Worker """ Regression test for gh-47090 """ 965*da0073e9SAndroid Build Coastguard Worker def test_max(self): 966*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2]) 967*da0073e9SAndroid Build Coastguard Worker xs = x.as_subclass(SubTensor2) 968*da0073e9SAndroid Build Coastguard Worker r = torch.max(x, dim=0) 969*da0073e9SAndroid Build Coastguard Worker rs = torch.max(xs, dim=0) 970*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(r), type(rs)) 971*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r, rs) 972*da0073e9SAndroid Build Coastguard Worker 973*da0073e9SAndroid Build Coastguard Workerclass TestGradNewOnesOverride(TestCase): 974*da0073e9SAndroid Build Coastguard Worker """ Regression test for gh-47069 """ 975*da0073e9SAndroid Build Coastguard Worker def test_newones(self): 976*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([1, 2]).as_subclass(SubTensor2) 977*da0073e9SAndroid Build Coastguard Worker n = t.new_ones((1, 2)) 
978*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(n), SubTensor2) 979*da0073e9SAndroid Build Coastguard Worker 980*da0073e9SAndroid Build Coastguard Workerclass TestPickle(TestCase): 981*da0073e9SAndroid Build Coastguard Worker "Regression test for gh-47051" 982*da0073e9SAndroid Build Coastguard Worker def test_pickle(self): 983*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([1]).as_subclass(SubTensor2) 984*da0073e9SAndroid Build Coastguard Worker t.abcd = "e" 985*da0073e9SAndroid Build Coastguard Worker t2 = pickle.loads(pickle.dumps(t)) 986*da0073e9SAndroid Build Coastguard Worker self.assertIs(type(t2), SubTensor2) 987*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t2.abcd, "e") 988*da0073e9SAndroid Build Coastguard Worker 989*da0073e9SAndroid Build Coastguard Workerclass TestBroadcastAllOverride(TestCase): 990*da0073e9SAndroid Build Coastguard Worker """ test for gh-37141 """ 991*da0073e9SAndroid Build Coastguard Worker def test_broadcast_all(self): 992*da0073e9SAndroid Build Coastguard Worker from torch.distributions.utils import broadcast_all 993*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1.2, 3.4, 5.6]) 994*da0073e9SAndroid Build Coastguard Worker a_w = Wrapper(a) 995*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(5.0) 996*da0073e9SAndroid Build Coastguard Worker b_w = Wrapper(b) 997*da0073e9SAndroid Build Coastguard Worker c = torch.tensor([5.0, 5.0, 5.0]) 998*da0073e9SAndroid Build Coastguard Worker 999*da0073e9SAndroid Build Coastguard Worker o_1 = broadcast_all(a_w, b_w) 1000*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(o_1[0], Wrapper)) 1001*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(o_1[1], Wrapper)) 1002*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o_1[0]._data, a) 1003*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o_1[1]._data, c) 1004*da0073e9SAndroid Build Coastguard Worker 1005*da0073e9SAndroid Build Coastguard 
Worker o_2 = broadcast_all(a_w, b) 1006*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(o_2[0], Wrapper)) 1007*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(o_2[1], Wrapper)) 1008*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o_2[0]._data, a) 1009*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o_2[1]._data, c) 1010*da0073e9SAndroid Build Coastguard Worker 1011*da0073e9SAndroid Build Coastguard Workerclass TestWrapTorchFunction(TestCase): 1012*da0073e9SAndroid Build Coastguard Worker def test_wrap_torch_function(self): 1013*da0073e9SAndroid Build Coastguard Worker class A: 1014*da0073e9SAndroid Build Coastguard Worker @classmethod 1015*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args, kwargs): 1016*da0073e9SAndroid Build Coastguard Worker return -1 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker def dispatcher(a): 1019*da0073e9SAndroid Build Coastguard Worker return (a,) 1020*da0073e9SAndroid Build Coastguard Worker 1021*da0073e9SAndroid Build Coastguard Worker @torch.overrides.wrap_torch_function(dispatcher) 1022*da0073e9SAndroid Build Coastguard Worker def f(a): 1023*da0073e9SAndroid Build Coastguard Worker return a 1024*da0073e9SAndroid Build Coastguard Worker 1025*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(A()), -1) 1026*da0073e9SAndroid Build Coastguard Worker 1027*da0073e9SAndroid Build Coastguard Workerclass TestIndexing(TestCase): 1028*da0073e9SAndroid Build Coastguard Worker """ Regression tests for gh-46277 """ 1029*da0073e9SAndroid Build Coastguard Worker def test_getitem(self): 1030*da0073e9SAndroid Build Coastguard Worker class A: 1031*da0073e9SAndroid Build Coastguard Worker @classmethod 1032*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args, kwargs=None): 1033*da0073e9SAndroid Build Coastguard Worker return -1 1034*da0073e9SAndroid Build Coastguard 
Worker 1035*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([5]) 1036*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[A()], -1) 1037*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, torch.tensor([5])) 1038*da0073e9SAndroid Build Coastguard Worker 1039*da0073e9SAndroid Build Coastguard Worker def test_getitem_subclass(self): 1040*da0073e9SAndroid Build Coastguard Worker class A(torch.Tensor): 1041*da0073e9SAndroid Build Coastguard Worker @classmethod 1042*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args, kwargs=None): 1043*da0073e9SAndroid Build Coastguard Worker return -1 1044*da0073e9SAndroid Build Coastguard Worker 1045*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([5]) 1046*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[A()], -1) 1047*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[5, A()], -1) 1048*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, torch.tensor([5])) 1049*da0073e9SAndroid Build Coastguard Worker 1050*da0073e9SAndroid Build Coastguard Worker def test_setitem(self): 1051*da0073e9SAndroid Build Coastguard Worker triggered = set() 1052*da0073e9SAndroid Build Coastguard Worker 1053*da0073e9SAndroid Build Coastguard Worker class A: 1054*da0073e9SAndroid Build Coastguard Worker @classmethod 1055*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args, kwargs=None): 1056*da0073e9SAndroid Build Coastguard Worker triggered.add(func) 1057*da0073e9SAndroid Build Coastguard Worker return -1 1058*da0073e9SAndroid Build Coastguard Worker 1059*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([5]) 1060*da0073e9SAndroid Build Coastguard Worker t[A()] = 1 1061*da0073e9SAndroid Build Coastguard Worker t[5, A()] = 1 1062*da0073e9SAndroid Build Coastguard Worker self.assertIn(Tensor.__setitem__, triggered) 1063*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, torch.tensor([5])) 
1064*da0073e9SAndroid Build Coastguard Worker 1065*da0073e9SAndroid Build Coastguard Worker def test_setitem_val(self): 1066*da0073e9SAndroid Build Coastguard Worker triggered = set() 1067*da0073e9SAndroid Build Coastguard Worker 1068*da0073e9SAndroid Build Coastguard Worker class A: 1069*da0073e9SAndroid Build Coastguard Worker @classmethod 1070*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args, kwargs=None): 1071*da0073e9SAndroid Build Coastguard Worker triggered.add(func) 1072*da0073e9SAndroid Build Coastguard Worker return -1 1073*da0073e9SAndroid Build Coastguard Worker 1074*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([5]) 1075*da0073e9SAndroid Build Coastguard Worker t[0] = A() 1076*da0073e9SAndroid Build Coastguard Worker self.assertIn(Tensor.__setitem__, triggered) 1077*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, torch.tensor([5])) 1078*da0073e9SAndroid Build Coastguard Worker 1079*da0073e9SAndroid Build Coastguard Worker def test_setitem_subclass(self): 1080*da0073e9SAndroid Build Coastguard Worker triggered = set() 1081*da0073e9SAndroid Build Coastguard Worker 1082*da0073e9SAndroid Build Coastguard Worker class A(torch.Tensor): 1083*da0073e9SAndroid Build Coastguard Worker @classmethod 1084*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args, kwargs=None): 1085*da0073e9SAndroid Build Coastguard Worker triggered.add(func) 1086*da0073e9SAndroid Build Coastguard Worker return -1 1087*da0073e9SAndroid Build Coastguard Worker 1088*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([5]) 1089*da0073e9SAndroid Build Coastguard Worker t[A()] = 1 1090*da0073e9SAndroid Build Coastguard Worker t[5, A()] = 1 1091*da0073e9SAndroid Build Coastguard Worker self.assertIn(Tensor.__setitem__, triggered) 1092*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t, torch.tensor([5])) 1093*da0073e9SAndroid Build Coastguard Worker 1094*da0073e9SAndroid Build 
class TestIterator(TestCase):
    # Regression test for gh-54457
    def test_iterator(self):
        # Every element produced by iterating a subclass instance must
        # itself be of that subclass.
        t = torch.tensor([5, 6, 7]).as_subclass(SubTensor2)
        it = iter(t)
        self.assertIs(type(next(it)), SubTensor2)
        self.assertIs(type(next(it)), SubTensor2)
        self.assertIs(type(next(it)), SubTensor2)


class TestRNN(TestCase):
    # Regression test for gh-55868
    def test_rnn(self):
        # Passes as long as running the module on a Wrapper does not raise.
        model = torch.nn.RNN(10, 20, 2)
        rnn_input = Wrapper(torch.randn(1, 5, 10))
        model(rnn_input)


class TestDisabledTorchFunction(TestCase):
    # Regression test for gh-64687
    def test_parameter_does_not_prevent_dispatch(self):
        # A Parameter among the arguments must not stop dispatch to another
        # argument's __torch_function__, whatever the argument order.
        class MyTensor:
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                return "called"

        t1 = MyTensor()
        t2 = torch.nn.Parameter(torch.rand(2, 2))
        self.assertEqual(torch.add(t2, t1), "called")

        inp = torch.rand(10, 10)
        self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
        self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")

class TestResolveName(TestCase):
    def test_resolve_name(self):
        # The name reported for each overridable function must evaluate
        # back to that same function. eval() only ever sees strings
        # produced by resolve_name here, never external input.
        for cs in get_overridable_functions().values():
            for c in cs:
                self.assertEqual(
                    eval(torch.overrides.resolve_name(c)),
                    c,
                    msg=f"{c}, {torch.overrides.resolve_name(c)}"
                )

class TestTorchFunctionWarning(TestCase):
    def test_warn_on_invalid_torch_function_standalone_class(self):
        # __torch_function__ defined as a plain (non-classmethod) method on
        # an ordinary class must trigger the deprecation warning.
        class StandaloneTorchFunctionClass:
            def __torch_function__(self, *args, **kwargs):
                pass
        a = StandaloneTorchFunctionClass()
        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
            # Function that handles torch_function on the python side
            torch.nn.functional.dropout(a)
        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
            # Function that handles torch_function in C++
            torch.abs(a)

    def test_warn_on_invalid_torch_function_tensor_subclass(self):
        # Same check as above for a Tensor subclass.
        class TensorSubclassTorchFunctionClass(torch.Tensor):
            def __torch_function__(self, *args, **kwargs):
                pass
        b = TensorSubclassTorchFunctionClass()
        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
            # Function that handles torch_function on the python side
            torch.nn.functional.dropout(b)
        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
            # Function that handles torch_function in C++
            torch.abs(b)

class TestDisabledUserWarnings(TestCase):
    def test_no_implicit_user_warning_for_deprecated_functions(self):
        # None of these torch.overrides helpers may emit a warning.
        self.assertNotWarn(get_ignored_functions)
        self.assertNotWarn(get_testing_overrides)
        self.assertNotWarn(get_overridable_functions)
        self.assertNotWarn(lambda: resolve_name(torch.Tensor.add))
        self.assertNotWarn(lambda: is_tensor_method_or_property(torch.Tensor.add))

@unittest.skipIf(TEST_WITH_CROSSREF, "not run with crossref")
class TestTorchFunctionMode(TestCase):
    def test_basic(self):
        # An active mode intercepts every call, whatever the arguments.
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -1
        # NB: factory functions get overridden too!
        x = torch.randn(1)
        with A():
            self.assertEqual(torch.randn(3), -1)
            self.assertEqual(torch.add(x, x), -1)
            self.assertEqual(torch.split(None, [2]), -1)  # python side
            self.assertEqual(bar(x), -1)

    def test_factory_override(self):
        # Tensor factory functions are intercepted by the mode as well.
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -1

        with A():
            self.assertEqual(torch.tensor([1]), -1)
            self.assertEqual(torch.sparse_coo_tensor(1, 1, 1), -1)
            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1), -1)
            self.assertEqual(torch.sparse_coo_tensor(1, 1, (1, 1), check_invariants=False), -1)
            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1, (1, 1), check_invariants=False), -1)
            self.assertEqual(torch.as_tensor([1]), -1)

    def test_modes_handle_first(self):
        # The mode's result wins even when an argument is a SubTensor.
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -40

        x = SubTensor()
        with A():
            self.assertEqual(torch.neg(x), -40)
            self.assertEqual(torch.mean(x), -40)
            self.assertEqual(torch.mm(x, x), -40)
            self.assertEqual(bar(x), -40)

    def test_modes_return_notimplemented(self):
        # When the mode returns NotImplemented, the values asserted below
        # are expected to come from SubTensor's own overrides
        # (NOTE(review): SubTensor is defined elsewhere in this file).
        class MyMode(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return NotImplemented

        x = SubTensor()
        with MyMode():
            self.assertEqual(torch.mean(x), 0)
            self.assertEqual(torch.mm(x, x), -1)
            self.assertEqual(bar(x), 1)
            # Call torch.max directly: wrapping it in a one-argument
            # assertEqual could only produce an unrelated TypeError.
            with self.assertRaisesRegex(TypeError, r'SubTensor'):
                torch.max(x, x)

    def test_with_mode(self):
        # An exception raised inside the mode propagates to the caller.
        class ErrorA(RuntimeError):
            pass

        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                raise ErrorA

        with self.assertRaises(ErrorA):
            with A():
                torch.empty([])

    def test_with_mode_created_separately(self):
        # Same as test_with_mode, with the mode instance created before
        # the `with` statement.
        class ErrorA(RuntimeError):
            pass

        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                raise ErrorA

        x = A()
        with self.assertRaises(ErrorA):
            with x:
                torch.empty([])

    def test_with_nested_modes(self):
        # The innermost mode handles the call first, then the outer one.
        out = []

        class A(TorchFunctionMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_function__(self, func, _, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                out.append(self.msg)
                return func(*args, **kwargs)

        with A("layer1"):
            with A("layer2"):
                torch.empty([])

        self.assertEqual(out, ["layer2", "layer1"])

    def test_nested_same_mode(self):
        # Entering the same mode instance twice makes it run twice.
        out = []

        class A(TorchFunctionMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_function__(self, func, _, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                out.append(self.msg)
                return func(*args, **kwargs)

        with A("layer1") as a:
            with a:
                torch.empty([])

        self.assertEqual(out, ["layer1", "layer1"])

    def test_error_using_class_method_on_mode(self):
        # Declaring __torch_function__ as a classmethod on a mode is
        # rejected with a RuntimeError when the mode is used.
        class A(TorchFunctionMode):
            @classmethod
            def __torch_function__(cls, func, _, args=(), kwargs=None):
                return func(args, kwargs)

        x = torch.tensor(5.)
        with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"):
            with A():
                x + x

    def test_restacking_with_ancestor(self):
        # A mode instance first entered inside another mode can be entered
        # again later on its own.
        class A(TorchFunctionMode):
            pass

        with A():
            with A() as x:
                pass

        with x:
            pass

    def test_get_cur_mode(self):
        # _get_current_function_mode reports the innermost active mode.
        class A(TorchFunctionMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                pass

        with A() as mode1:
            self.assertEqual(_get_current_function_mode(), mode1)

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_function_mode(), mode2)


    def test_get_mode_stack(self):
        # _get_current_function_mode_stack lists active modes, outermost
        # first, and is empty when no mode is active.
        class A(TorchFunctionMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                pass

        self.assertEqual(_get_current_function_mode_stack(), [])

        with A() as mode1:
            self.assertEqual(_get_current_function_mode_stack(), [mode1])

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2])

    def test_all_same_mode(self):
        # all_same_mode is true only when every entry is the same instance.
        class A(TorchFunctionMode):
            pass

        x = A()
        y = A()
        self.assertTrue(all_same_mode([x, x, x]))
        self.assertFalse(all_same_mode([x, None]))
        self.assertFalse(all_same_mode([x, y]))

    def test_nested_modes_with_python_has_torch_function(self):
        # For a function dispatched from Python (`bar`), nested modes run
        # innermost first and the final result is unchanged.
        called = []

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                called.append("A")
                kwargs = {} if kwargs is None else kwargs
                return func(*args, **kwargs)

        class B(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                called.append("B")
                kwargs = {} if kwargs is None else kwargs
                return func(*args, **kwargs)

        x = torch.randn(3, 4)
        with A():
            with B():
                y = bar(x)

        self.assertEqual(y, x)
        self.assertEqual(called, ["B", "A"])


    def test_reentrant_mode_idiom(self):
        log = []

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                log.append(func)
                if func is torch.sub:
                    with self:
                        input, other = args
                        assert not kwargs
                        return torch.add(input, other, alpha=-1)
                return func(*args, **kwargs)

        x = torch.randn(1)
        y = torch.randn(1)
        with A():
            torch.sub(x, y)
        # add hits the torch function again!
1391*da0073e9SAndroid Build Coastguard Worker self.assertEqual(log, [torch.sub, torch.add]) 1392*da0073e9SAndroid Build Coastguard Worker 1393*da0073e9SAndroid Build Coastguard Worker def test_nn_parse_to(self): 1394*da0073e9SAndroid Build Coastguard Worker # This failed because the parser thinks the function is called to() 1395*da0073e9SAndroid Build Coastguard Worker # but it's actually called _parse_to() 1396*da0073e9SAndroid Build Coastguard Worker 1397*da0073e9SAndroid Build Coastguard Worker called = False 1398*da0073e9SAndroid Build Coastguard Worker 1399*da0073e9SAndroid Build Coastguard Worker class A(TorchFunctionMode): 1400*da0073e9SAndroid Build Coastguard Worker def __torch_function__(self, func, types, args=(), kwargs=None): 1401*da0073e9SAndroid Build Coastguard Worker nonlocal called 1402*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1403*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1404*da0073e9SAndroid Build Coastguard Worker called = True 1405*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1406*da0073e9SAndroid Build Coastguard Worker 1407*da0073e9SAndroid Build Coastguard Worker with A(): 1408*da0073e9SAndroid Build Coastguard Worker torch._C._nn._parse_to('cpu') 1409*da0073e9SAndroid Build Coastguard Worker 1410*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 1411*da0073e9SAndroid Build Coastguard Worker 1412*da0073e9SAndroid Build Coastguard Worker def test_getitem_call(self): 1413*da0073e9SAndroid Build Coastguard Worker # This failed because the parser thinks the function is called to() 1414*da0073e9SAndroid Build Coastguard Worker # but it's actually called _parse_to() 1415*da0073e9SAndroid Build Coastguard Worker 1416*da0073e9SAndroid Build Coastguard Worker called = False 1417*da0073e9SAndroid Build Coastguard Worker 1418*da0073e9SAndroid Build Coastguard Worker class A(TorchFunctionMode): 1419*da0073e9SAndroid Build Coastguard Worker def __torch_function__(self, func, types, 
args=(), kwargs=None): 1420*da0073e9SAndroid Build Coastguard Worker nonlocal called 1421*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1422*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1423*da0073e9SAndroid Build Coastguard Worker called = True 1424*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1425*da0073e9SAndroid Build Coastguard Worker 1426*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(5) 1427*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(0) 1428*da0073e9SAndroid Build Coastguard Worker with A(): 1429*da0073e9SAndroid Build Coastguard Worker a[b] 1430*da0073e9SAndroid Build Coastguard Worker 1431*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 1432*da0073e9SAndroid Build Coastguard Worker 1433*da0073e9SAndroid Build Coastguard Worker 1434*da0073e9SAndroid Build Coastguard Worker def test_distributions_bernoulli(self): 1435*da0073e9SAndroid Build Coastguard Worker # This failed because improper use of has_torch_function when 1436*da0073e9SAndroid Build Coastguard Worker # is_tensor_like should have been used instead, inside the 1437*da0073e9SAndroid Build Coastguard Worker # broadcasting logic called by distributions (Bernoulli doesn't 1438*da0073e9SAndroid Build Coastguard Worker # matter per se) 1439*da0073e9SAndroid Build Coastguard Worker 1440*da0073e9SAndroid Build Coastguard Worker called = False 1441*da0073e9SAndroid Build Coastguard Worker 1442*da0073e9SAndroid Build Coastguard Worker class A(TorchFunctionMode): 1443*da0073e9SAndroid Build Coastguard Worker def __torch_function__(self, func, types, args=(), kwargs=None): 1444*da0073e9SAndroid Build Coastguard Worker nonlocal called 1445*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1446*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1447*da0073e9SAndroid Build Coastguard Worker called = True 1448*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1449*da0073e9SAndroid Build 
Coastguard Worker 1450*da0073e9SAndroid Build Coastguard Worker with A(): 1451*da0073e9SAndroid Build Coastguard Worker torch.distributions.Bernoulli(0.3) 1452*da0073e9SAndroid Build Coastguard Worker 1453*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 1454*da0073e9SAndroid Build Coastguard Worker 1455*da0073e9SAndroid Build Coastguard Worker def test_mode_notimplemented_loop(self): 1456*da0073e9SAndroid Build Coastguard Worker # Default tensor subclass implementation disables torch function; 1457*da0073e9SAndroid Build Coastguard Worker # when we redispatch to mode we must not treat the objects as 1458*da0073e9SAndroid Build Coastguard Worker # eligible 1459*da0073e9SAndroid Build Coastguard Worker 1460*da0073e9SAndroid Build Coastguard Worker called = 0 1461*da0073e9SAndroid Build Coastguard Worker 1462*da0073e9SAndroid Build Coastguard Worker class A(TorchFunctionMode): 1463*da0073e9SAndroid Build Coastguard Worker def __torch_function__(self, func, types, args=(), kwargs=None): 1464*da0073e9SAndroid Build Coastguard Worker nonlocal called 1465*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1466*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1467*da0073e9SAndroid Build Coastguard Worker called += 1 1468*da0073e9SAndroid Build Coastguard Worker # The first time we call, the mode sees an active type that 1469*da0073e9SAndroid Build Coastguard Worker # it doesn't know how to deal with. The second time, we're 1470*da0073e9SAndroid Build Coastguard Worker # instructed to treat it "as if it were a tensor", and so 1471*da0073e9SAndroid Build Coastguard Worker # we keep going. I'm not entirely clear if the subclasses 1472*da0073e9SAndroid Build Coastguard Worker # disappearing from types is the correct way to do it. 
1473*da0073e9SAndroid Build Coastguard Worker if any(t is not torch.Tensor for t in types): 1474*da0073e9SAndroid Build Coastguard Worker return NotImplemented 1475*da0073e9SAndroid Build Coastguard Worker else: 1476*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1477*da0073e9SAndroid Build Coastguard Worker 1478*da0073e9SAndroid Build Coastguard Worker class B(torch.Tensor): 1479*da0073e9SAndroid Build Coastguard Worker pass 1480*da0073e9SAndroid Build Coastguard Worker 1481*da0073e9SAndroid Build Coastguard Worker b = B() 1482*da0073e9SAndroid Build Coastguard Worker 1483*da0073e9SAndroid Build Coastguard Worker with A(): 1484*da0073e9SAndroid Build Coastguard Worker r = torch.neg(b) 1485*da0073e9SAndroid Build Coastguard Worker 1486*da0073e9SAndroid Build Coastguard Worker self.assertIs(type(r), B) 1487*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 2) 1488*da0073e9SAndroid Build Coastguard Worker 1489*da0073e9SAndroid Build Coastguard Worker called = 0 1490*da0073e9SAndroid Build Coastguard Worker 1491*da0073e9SAndroid Build Coastguard Worker with A(): 1492*da0073e9SAndroid Build Coastguard Worker r = bar(b) 1493*da0073e9SAndroid Build Coastguard Worker 1494*da0073e9SAndroid Build Coastguard Worker self.assertIs(type(r), B) 1495*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 2) 1496*da0073e9SAndroid Build Coastguard Worker 1497*da0073e9SAndroid Build Coastguard Worker def test_disable_subclass_not_mode(self): 1498*da0073e9SAndroid Build Coastguard Worker called = False 1499*da0073e9SAndroid Build Coastguard Worker 1500*da0073e9SAndroid Build Coastguard Worker class A(TorchFunctionMode): 1501*da0073e9SAndroid Build Coastguard Worker def __torch_function__(self, func, types, args=(), kwargs=None): 1502*da0073e9SAndroid Build Coastguard Worker nonlocal called 1503*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1504*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1505*da0073e9SAndroid 
Build Coastguard Worker called = True 1506*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1507*da0073e9SAndroid Build Coastguard Worker 1508*da0073e9SAndroid Build Coastguard Worker class B(torch.Tensor): 1509*da0073e9SAndroid Build Coastguard Worker pass 1510*da0073e9SAndroid Build Coastguard Worker 1511*da0073e9SAndroid Build Coastguard Worker x = B(torch.randn(5)) 1512*da0073e9SAndroid Build Coastguard Worker with A(): 1513*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunctionSubclass(): 1514*da0073e9SAndroid Build Coastguard Worker self.assertNotIsInstance(torch.sum(x), B) 1515*da0073e9SAndroid Build Coastguard Worker 1516*da0073e9SAndroid Build Coastguard Worker self.assertTrue(called) 1517*da0073e9SAndroid Build Coastguard Worker 1518*da0073e9SAndroid Build Coastguard Worker def test_disable_subclass_mode(self): 1519*da0073e9SAndroid Build Coastguard Worker called = False 1520*da0073e9SAndroid Build Coastguard Worker 1521*da0073e9SAndroid Build Coastguard Worker class A(TorchFunctionMode): 1522*da0073e9SAndroid Build Coastguard Worker def __torch_function__(self, func, types, args=(), kwargs=None): 1523*da0073e9SAndroid Build Coastguard Worker nonlocal called 1524*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1525*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1526*da0073e9SAndroid Build Coastguard Worker called = True 1527*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1528*da0073e9SAndroid Build Coastguard Worker 1529*da0073e9SAndroid Build Coastguard Worker class B(torch.Tensor): 1530*da0073e9SAndroid Build Coastguard Worker pass 1531*da0073e9SAndroid Build Coastguard Worker 1532*da0073e9SAndroid Build Coastguard Worker x = B(torch.randn(5)) 1533*da0073e9SAndroid Build Coastguard Worker with A(): 1534*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunction(): 1535*da0073e9SAndroid Build Coastguard Worker self.assertNotIsInstance(torch.sum(x), B) 
1536*da0073e9SAndroid Build Coastguard Worker 1537*da0073e9SAndroid Build Coastguard Worker self.assertFalse(called) 1538*da0073e9SAndroid Build Coastguard Worker 1539*da0073e9SAndroid Build Coastguard Worker def test_disable_enable_subclass(self): 1540*da0073e9SAndroid Build Coastguard Worker called = False 1541*da0073e9SAndroid Build Coastguard Worker 1542*da0073e9SAndroid Build Coastguard Worker class A(torch.Tensor): 1543*da0073e9SAndroid Build Coastguard Worker pass 1544*da0073e9SAndroid Build Coastguard Worker 1545*da0073e9SAndroid Build Coastguard Worker x = A(torch.randn(5)) 1546*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunctionSubclass(): 1547*da0073e9SAndroid Build Coastguard Worker g = torch._C._EnableTorchFunction() 1548*da0073e9SAndroid Build Coastguard Worker try: 1549*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(torch.sum(x), A) 1550*da0073e9SAndroid Build Coastguard Worker finally: 1551*da0073e9SAndroid Build Coastguard Worker del g 1552*da0073e9SAndroid Build Coastguard Worker 1553*da0073e9SAndroid Build Coastguard Worker def test_torch_function_all_disabled_api(self): 1554*da0073e9SAndroid Build Coastguard Worker from torch._C import _is_torch_function_all_disabled 1555*da0073e9SAndroid Build Coastguard Worker 1556*da0073e9SAndroid Build Coastguard Worker state = _is_torch_function_all_disabled() 1557*da0073e9SAndroid Build Coastguard Worker self.assertFalse(state) 1558*da0073e9SAndroid Build Coastguard Worker 1559*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunction(): 1560*da0073e9SAndroid Build Coastguard Worker state = _is_torch_function_all_disabled() 1561*da0073e9SAndroid Build Coastguard Worker self.assertTrue(state) 1562*da0073e9SAndroid Build Coastguard Worker 1563*da0073e9SAndroid Build Coastguard Worker state = _is_torch_function_all_disabled() 1564*da0073e9SAndroid Build Coastguard Worker self.assertFalse(state) 1565*da0073e9SAndroid Build Coastguard Worker 
1566*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunctionSubclass(): 1567*da0073e9SAndroid Build Coastguard Worker state = _is_torch_function_all_disabled() 1568*da0073e9SAndroid Build Coastguard Worker self.assertFalse(state) 1569*da0073e9SAndroid Build Coastguard Worker 1570*da0073e9SAndroid Build Coastguard Worker def test_subclass_hash(self): 1571*da0073e9SAndroid Build Coastguard Worker class DiagTensor(torch.Tensor): 1572*da0073e9SAndroid Build Coastguard Worker def __init__(self, diag): 1573*da0073e9SAndroid Build Coastguard Worker self._diag = diag 1574*da0073e9SAndroid Build Coastguard Worker 1575*da0073e9SAndroid Build Coastguard Worker @classmethod 1576*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 1577*da0073e9SAndroid Build Coastguard Worker kwargs = kwargs or {} 1578*da0073e9SAndroid Build Coastguard Worker 1579*da0073e9SAndroid Build Coastguard Worker def get_full_matrices(t): 1580*da0073e9SAndroid Build Coastguard Worker if isinstance(t, DiagTensor): 1581*da0073e9SAndroid Build Coastguard Worker return torch.diag_embed(t._diag) 1582*da0073e9SAndroid Build Coastguard Worker else: 1583*da0073e9SAndroid Build Coastguard Worker return t 1584*da0073e9SAndroid Build Coastguard Worker 1585*da0073e9SAndroid Build Coastguard Worker return func(*tree_map(get_full_matrices, args), **tree_map(get_full_matrices, kwargs)) 1586*da0073e9SAndroid Build Coastguard Worker 1587*da0073e9SAndroid Build Coastguard Worker d = torch.rand(2) 1588*da0073e9SAndroid Build Coastguard Worker a = DiagTensor(d) 1589*da0073e9SAndroid Build Coastguard Worker 1590*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a + 1), torch.diag_embed(d) + 1) 1591*da0073e9SAndroid Build Coastguard Worker 1592*da0073e9SAndroid Build Coastguard Worker # If the hash function was returning the same value, this would 1593*da0073e9SAndroid Build Coastguard Worker # fail inside `Tensor.__eq__`. 
1594*da0073e9SAndroid Build Coastguard Worker # If __hash__ was going through torch_function, the implementation above would 1595*da0073e9SAndroid Build Coastguard Worker # be wrong as it would compute the hash on a temporary Tensor thus not ensuring 1596*da0073e9SAndroid Build Coastguard Worker # the uniqueness of the hash that we rely on for Tensors. 1597*da0073e9SAndroid Build Coastguard Worker s = set() 1598*da0073e9SAndroid Build Coastguard Worker s.add(a) 1599*da0073e9SAndroid Build Coastguard Worker s.add(DiagTensor(d)) 1600*da0073e9SAndroid Build Coastguard Worker 1601*da0073e9SAndroid Build Coastguard Worker def test_custom_device_type(self): 1602*da0073e9SAndroid Build Coastguard Worker class CustomDeviceContext(TorchFunctionMode): 1603*da0073e9SAndroid Build Coastguard Worker 1604*da0073e9SAndroid Build Coastguard Worker def __torch_function__(self, func, types, args=(), kwargs=None): 1605*da0073e9SAndroid Build Coastguard Worker kwargs = kwargs or {} 1606*da0073e9SAndroid Build Coastguard Worker if func == torch.device: 1607*da0073e9SAndroid Build Coastguard Worker if args and isinstance(args[0], int): 1608*da0073e9SAndroid Build Coastguard Worker args = ("xla", args[0]) 1609*da0073e9SAndroid Build Coastguard Worker elif isinstance(kwargs.get('device'), int): 1610*da0073e9SAndroid Build Coastguard Worker kwargs['device'] = f"xla:{kwargs.get('device')}" 1611*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1612*da0073e9SAndroid Build Coastguard Worker 1613*da0073e9SAndroid Build Coastguard Worker with CustomDeviceContext(): 1614*da0073e9SAndroid Build Coastguard Worker d_args = torch.device(0) 1615*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d_args.type, "xla") 1616*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d_args.index, 0) 1617*da0073e9SAndroid Build Coastguard Worker d_kwargs = torch.device(device=0) 1618*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d_kwargs.type, "xla") 
1619*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d_kwargs.index, 0) 1620*da0073e9SAndroid Build Coastguard Worker 1621*da0073e9SAndroid Build Coastguard Worker def test_device_context_semantics(self): 1622*da0073e9SAndroid Build Coastguard Worker from torch._C import _len_torch_function_stack 1623*da0073e9SAndroid Build Coastguard Worker from torch.utils._device import DeviceContext 1624*da0073e9SAndroid Build Coastguard Worker try: 1625*da0073e9SAndroid Build Coastguard Worker torch.set_default_device("cuda") 1626*da0073e9SAndroid Build Coastguard Worker 1627*da0073e9SAndroid Build Coastguard Worker def get_stack(): 1628*da0073e9SAndroid Build Coastguard Worker return [torch._C._get_function_stack_at(i) for i in range(_len_torch_function_stack())] 1629*da0073e9SAndroid Build Coastguard Worker 1630*da0073e9SAndroid Build Coastguard Worker base_mode = BaseTorchFunctionMode() 1631*da0073e9SAndroid Build Coastguard Worker with base_mode: 1632*da0073e9SAndroid Build Coastguard Worker torch.set_default_device("cpu") 1633*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 1634*da0073e9SAndroid Build Coastguard Worker stack = get_stack() 1635*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(stack[0], DeviceContext) 1636*da0073e9SAndroid Build Coastguard Worker self.assertEqual(stack[0].device, torch.device("cpu")) 1637*da0073e9SAndroid Build Coastguard Worker 1638*da0073e9SAndroid Build Coastguard Worker stack = get_stack() 1639*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(stack[0], DeviceContext) 1640*da0073e9SAndroid Build Coastguard Worker self.assertEqual(stack[0].device, torch.device("cpu")) 1641*da0073e9SAndroid Build Coastguard Worker finally: 1642*da0073e9SAndroid Build Coastguard Worker torch.set_default_device(None) 1643*da0073e9SAndroid Build Coastguard Worker 1644*da0073e9SAndroid Build Coastguard Worker 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker 
# When executed as a script, hand control to run_tests (imported from
# torch.testing._internal.common_utils at the top of this file).
if __name__ == '__main__':
    run_tests()