1# Owner(s): ["module: __torch_function__"] 2 3import torch 4import numpy as np 5import inspect 6import functools 7import pprint 8import pickle 9import collections 10import unittest 11import contextlib 12 13from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO 14from torch.overrides import ( 15 handle_torch_function, 16 has_torch_function, 17 get_ignored_functions, 18 get_overridable_functions, 19 get_testing_overrides, 20 resolve_name, 21 is_tensor_method_or_property, 22 TorchFunctionMode, 23 _get_current_function_mode, 24 _get_current_function_mode_stack, 25 BaseTorchFunctionMode 26) 27from torch.utils._mode_utils import all_same_mode 28from torch.utils._pytree import tree_map 29 30Tensor = torch.Tensor 31 32# The functions below simulate the pure-python torch functions in the 33# torch.functional namespace. We use examples local to this file rather 34# than any of the real examples implemented in Python since in the 35# future those examples might get reimplemented in C++ for speed. This 36# fake torch function allows us to verify that the dispatch rules work 37# the same for a torch function implemented in C++ or Python. 

def foo(a, b, c=None):
    """A function with multiple arguments and an optional argument.

    Returns ``a + b``, or ``a + b + c`` when ``c`` is supplied. If any
    argument defines ``__torch_function__``, the call is routed through
    ``handle_torch_function`` instead.
    """
    if has_torch_function((a, b, c)):
        return handle_torch_function(foo, (a, b, c), a, b, c=c)
    # Compare against None explicitly: a bare ``if c:`` raises for tensors
    # with more than one element ("Boolean value of Tensor ... is ambiguous").
    if c is not None:
        return a + b + c
    return a + b

def bar(a):
    """A function with one argument; returns ``a`` unchanged unless overridden."""
    if has_torch_function((a,)):
        return handle_torch_function(bar, (a,), a)
    return a

def baz(a, b):
    """A function with multiple arguments; returns ``a + b`` unless overridden.

    The precedence tests use it as a function that none of the override
    classes implement.
    """
    if has_torch_function((a, b)):
        return handle_torch_function(baz, (a, b), a, b)
    return a + b

def quux(a):
    """Used to test that errors raised in user implementations get propagated.

    Returns ``a`` unchanged unless an override handles the call.
    """
    if has_torch_function((a,)):
        return handle_torch_function(quux, (a,), a)
    return a

# HANDLED_FUNCTIONS_DIAGONAL is a dispatch table that
# DiagonalTensor.__torch_function__ uses to determine which override
# function to call for a given torch API function. The keys of the
# dictionary are function objects from the torch API and the values are
# their override implementations. Implementations are added to
# HANDLED_FUNCTIONS_DIAGONAL by decorating a python function with
# implements_diagonal. See the overrides immediately below the definition
# of DiagonalTensor for usage examples.
HANDLED_FUNCTIONS_DIAGONAL = {}

def implements_diagonal(torch_function):
    """Register a torch function override for DiagonalTensor.

    This decorator takes a function in the torch API as a
    parameter. Applying this decorator to a function adds that function
    as the registered override for the torch function passed as a
    parameter to the decorator. See DiagonalTensor.__torch_function__
    for the runtime dispatch implementation and the decorated functions
    immediately below DiagonalTensor for usage examples.
    """
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_DIAGONAL[torch_function] = func
        # Return the function unchanged so it stays usable under its own name.
        return func
    return decorator

class DiagonalTensor:
    """A class with __torch_function__ and a specific diagonal representation

    This class has limited utility and is mostly useful for verifying that the
    dispatch mechanism works as expected. It is based on the `DiagonalArray
    example`_ in the NumPy documentation.

    Note that this class does *not* inherit from ``torch.Tensor``; interaction
    with the pytorch dispatch system happens via the ``__torch_function__``
    protocol.

    ``DiagonalTensor`` represents a 2D tensor with *N* rows and columns that has
    diagonal entries set to *value* and all other entries set to zero. The
    main functionality of ``DiagonalTensor`` is to provide a more compact
    string representation of a diagonal tensor than in the base tensor class:

    >>> d = DiagonalTensor(5, 2)
    >>> d
    DiagonalTensor(N=5, value=2)
    >>> d.tensor()
    tensor([[2., 0., 0., 0., 0.],
            [0., 2., 0., 0., 0.],
            [0., 0., 2., 0., 0.],
            [0., 0., 0., 2., 0.],
            [0., 0., 0., 0., 2.]])

    Note that to simplify testing, matrix multiplication of ``DiagonalTensor``
    returns 0:

    >>> torch.mm(d, d)
    0

    .. _DiagonalArray example:
        https://numpy.org/devdocs/user/basics.dispatch.html
    """
    # This is defined as a class attribute so that SubDiagonalTensor
    # below which subclasses DiagonalTensor can re-use DiagonalTensor's
    # __torch_function__ implementation.
    handled_functions = HANDLED_FUNCTIONS_DIAGONAL

    def __init__(self, N, value):
        self._N = N
        self._i = value

    def __repr__(self):
        return f"DiagonalTensor(N={self._N}, value={self._i})"

    def __array__(self, dtype=None, copy=None):
        """Materialize as a dense NumPy array.

        NumPy passes a requested ``dtype`` (e.g. ``np.asarray(d, dtype=...)``)
        and, since NumPy 2.0, may pass a ``copy`` keyword. A new array is
        built on every call, so ``copy`` needs no special handling.
        """
        arr = self._i * np.eye(self._N)
        if dtype is not None:
            arr = arr.astype(dtype, copy=False)
        return arr

    def tensor(self):
        """Materialize as a dense ``torch.Tensor``."""
        return self._i * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Look the table up via ``cls`` so subclasses can supply their own.
        if func not in cls.handled_functions:
            return NotImplemented
        return cls.handled_functions[func](*args, **kwargs)

    # NOTE: defining __eq__ without __hash__ makes instances unhashable.
    def __eq__(self, other):
        return type(other) is type(self) and self._N == other._N and self._i == other._i

@implements_diagonal(torch.mean)
def mean(mat):
    return float(mat._i) / mat._N

@implements_diagonal(torch.mm)
def diagonal_mm(mat1, mat2):
    return 0

@implements_diagonal(torch.div)
def diagonal_div(input, other, out=None):
    return -1

@implements_diagonal(torch.add)
def add(mat1, mat2):
    # Deliberately raises to test propagation of errors from overrides.
    raise ValueError

@implements_diagonal(foo)
def diagonal_foo(a, b, c=None):
    return -1

@implements_diagonal(bar)
def diagonal_bar(a):
    return -1

@implements_diagonal(quux)
def diagonal_quux(a):
    # Deliberately raises to test propagation of errors from overrides.
    raise ValueError

# The dispatch table for SubTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_SUB = {}

def implements_sub(torch_function):
    "Register a torch function override for SubTensor"
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_SUB[torch_function] = func
        # Return the function unchanged so it stays usable under its own name.
        return func
    return decorator

class SubTensor(torch.Tensor):
    """A subclass of torch.Tensor used for testing __torch_function__ dispatch

    This class has the property that matrix multiplication returns -1 (see
    ``sub_mm`` below) whenever at least one operand is a ``SubTensor``:

    >>> s = SubTensor([[1, 1], [1, 1]])
    >>> torch.mm(s, s)
    -1
    >>> t = torch.tensor([[1, 1], [1, 1]])
    >>> torch.mm(s, t)
    -1
    >>> torch.mm(t, s)
    -1
    >>> torch.mm(t, t)
    tensor([[2, 2],
            [2, 2]])

    Functions without an entry in ``HANDLED_FUNCTIONS_SUB`` return
    ``NotImplemented`` from ``__torch_function__``.

    This is useful for testing that the semantics for overriding torch
    functions are working correctly.
    """
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func not in HANDLED_FUNCTIONS_SUB:
            return NotImplemented
        return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs)

class SubTensor2(torch.Tensor):
    """Plain Tensor subclass (no override) used to test subclass propagation."""

class SubSubTensor2(SubTensor2):
    """Subclass of SubTensor2 used to test propagation of the leaf subclass."""

class SubTensor3(torch.Tensor):
    """Tensor subclass unrelated to SubTensor2, used to test that unrelated
    subclass trees are not merged."""

@implements_sub(torch.mean)
def sub_mean(mat):
    return 0

@implements_sub(torch.mm)
def sub_mm(mat1, mat2):
    return -1

@implements_sub(bar)
def sub_bar(mat):
    return 1

@implements_sub(torch.div)
def sub_div(input, other, out=None):
    # Explicitly declines so that precedence falls through to other types.
    return NotImplemented

# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_SUB_DIAGONAL = {}

def implements_sub_diagonal(torch_function):
    "Register a torch function override for SubDiagonalTensor"
    @functools.wraps(torch_function)
    def decorator(func):
        HANDLED_FUNCTIONS_SUB_DIAGONAL[torch_function] = func
        # Return the function unchanged so it stays usable under its own name.
        return func
    return decorator

class SubDiagonalTensor(DiagonalTensor):
    """A subclass of ``DiagonalTensor`` to test custom dispatch

    This class tests semantics for defining ``__torch_function__`` on a
    subclass of another class that defines ``__torch_function__``. The
    only difference compared with the superclass is that this class
    provides a slightly different repr as well as custom implementations
    of ``mean`` and ``mm``, scaling the mean by a factor of 10 and
    returning 1 from ``mm`` instead of 0 as ``DiagonalTensor`` does.

    The inherited ``__torch_function__`` looks functions up in
    ``cls.handled_functions``, which this class replaces with
    ``HANDLED_FUNCTIONS_SUB_DIAGONAL``; functions registered only for
    ``DiagonalTensor`` are therefore not handled by this class.
    """
    handled_functions = HANDLED_FUNCTIONS_SUB_DIAGONAL

    def __repr__(self):
        return f"SubDiagonalTensor(N={self._N}, value={self._i})"


@implements_sub_diagonal(torch.mean)
def sub_diagonal_mean(mat):
    return 10 * float(mat._i) / mat._N

@implements_sub_diagonal(bar)
def sub_diagonal_bar(mat):
    return 0

@implements_sub_diagonal(torch.mm)
def sub_diagonal_mm(mat1, mat2):
    return 1

@implements_sub_diagonal(torch.div)
def sub_diagonal_div(input, other, out=None):
    # Explicitly declines so that precedence falls through to other types.
    return NotImplemented

@implements_sub_diagonal(foo)
def sub_diagonal_foo(a, b, c=None):
    # Explicitly declines so that precedence falls through to other types.
    return NotImplemented

# The dispatch table for TensorLike's __torch_function__ implementation.
HANDLED_FUNCTIONS_TENSOR_LIKE = {}


# Note: _triggered wrapper
# Dict that maps each function from get_testing_overrides to its override
# implementation wrapped in another function with a _triggered flag. The
# triggered flag is set when the implementation is called.
303WRAPPED_TRIGGERED_IMPLS = {} 304 305 306def triggered_wrapper(f): 307 @functools.wraps(f) 308 def wrapped(*args, **kwargs): 309 wrapped._triggered = True 310 return f(*args, **kwargs) 311 312 wrapped._triggered = False 313 return wrapped 314 315def implements_tensor_like(torch_function): 316 "Register a torch function override for TensorLike" 317 @functools.wraps(torch_function) 318 def decorator(func): 319 HANDLED_FUNCTIONS_TENSOR_LIKE[torch_function] = func 320 return func 321 return decorator 322 323def generate_tensor_like_torch_implementations(): 324 torch_vars = vars(torch) 325 untested_funcs = [] 326 testing_overrides = get_testing_overrides() 327 # test/test_cpp_api_parity.py monkeypatches torch.nn to have a new 328 # function sample_functional. Depending on what order you run pytest 329 # collection, this may trigger the error here. This is a hack to fix 330 # the problem. A more proper fix is to make the "not tested" check 331 # a test on its own, and to make sure the monkeypatch is only installed 332 # for the span of the relevant test (and deleted afterwards) 333 testing_ignore = {"sample_functional", "autocast"} 334 for namespace, funcs in get_overridable_functions().items(): 335 for func in funcs: 336 if func not in testing_overrides and func.__name__ not in testing_ignore: 337 untested_funcs.append(f"{namespace}.{func.__name__}") 338 msg = ( 339 "The following functions are not tested for __torch_function__ " 340 "support, please ensure there is an entry in the dict returned by " 341 "torch.overrides.get_testing_overrides for this function or if a " 342 "__torch_function__ override does not make sense, add an entry to " 343 "the tuple returned by torch._overrides.get_ignored_functions.\n\n{}" 344 ) 345 assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) 346 for func, override in testing_overrides.items(): 347 # decorate the overrides with implements_tensor_like if it's not a 348 # torch.Tensor method 349 wrapped = 
triggered_wrapper(override) 350 # See note: "_triggered wrapper" 351 WRAPPED_TRIGGERED_IMPLS[func] = wrapped 352 if is_tensor_method_or_property(func): 353 implements_sub(func)(wrapped) 354 else: 355 implements_tensor_like(func)(wrapped) 356 357generate_tensor_like_torch_implementations() 358 359class TensorLike: 360 """A class that overrides the full torch API 361 362 This class is used to explicitly test that the full torch.tensor API 363 can be overriden with a class that defines __torch_function__. 364 """ 365 @classmethod 366 def __torch_function__(cls, func, types, args=(), kwargs=None): 367 if kwargs is None: 368 kwargs = {} 369 370 if func not in HANDLED_FUNCTIONS_TENSOR_LIKE: 371 return NotImplemented 372 # In this case _torch_function_ should override TensorLike objects 373 return HANDLED_FUNCTIONS_TENSOR_LIKE[func](*args, **kwargs) 374 375class TestTorchFunctionOverride(TestCase): 376 @classmethod 377 def setUpClass(cls): 378 cls._stack = contextlib.ExitStack() 379 if TEST_WITH_TORCHDYNAMO: 380 # Add classes to the wrapped tensor subclasses 381 @contextlib.contextmanager 382 def setup_subclasses(): 383 old = set(torch._dynamo.config.traceable_tensor_subclasses) 384 torch._dynamo.config.traceable_tensor_subclasses.add(DiagonalTensor) 385 try: 386 yield 387 finally: 388 torch._dynamo.config.traceable_tensor_subclasses.clear() 389 torch._dynamo.config.traceable_tensor_subclasses.update(old) 390 391 cls._stack.enter_context(setup_subclasses()) 392 393 @classmethod 394 def tearDownClass(cls): 395 cls._stack.close() 396 397 def test_mean_semantics(self): 398 """Test that a function with one argument can be overridden""" 399 t1 = DiagonalTensor(5, 2) 400 t2 = SubTensor([[1, 2], [1, 2]]) 401 t3 = SubDiagonalTensor(5, 2) 402 self.assertEqual(torch.mean(t1), 0.4) 403 self.assertEqual(bar(t1), -1) 404 self.assertEqual(torch.mean(t2), 0) 405 self.assertEqual(bar(t2), 1) 406 self.assertEqual(torch.mean(t3), 4.0) 407 self.assertEqual(bar(t3), 0) 408 409 def 
test_has_torch_function_non_sequence(self): 410 with self.assertRaisesRegex(TypeError, "expected a sequence"): 411 has_torch_function(object()) 412 413 def test_mm_semantics(self): 414 """Test that a function with multiple arguments can be overridden""" 415 t1 = DiagonalTensor(5, 2) 416 t2 = torch.eye(5) * 2 417 t3 = SubTensor([[1, 2], [1, 2]]) 418 t4 = SubDiagonalTensor(5, 2) 419 # only DiagonalTensor so should always get DiagonalTensor result 420 self.assertEqual(torch.mm(t1, t1), 0) 421 # tensor and DiagonalTensor, always return DiagonalTensor result 422 self.assertEqual(torch.mm(t1, t2), 0) 423 self.assertEqual(torch.mm(t2, t1), 0) 424 # only SubTensor so should always get SubTensor result 425 self.assertEqual(torch.mm(t3, t3), -1) 426 # tensor and SubTensor so should always get SubTensor result 427 self.assertEqual(torch.mm(t3, t2), -1) 428 self.assertEqual(torch.mm(t2, t3), -1) 429 # DiagonalTensor and SubTensor are unrelated classes so the result 430 # depends on which argument appears first 431 self.assertEqual(torch.mm(t3, t1), -1) 432 self.assertEqual(torch.mm(t1, t3), 0) 433 # SubDiagonalTensor should take precedence over DiagonalTensor 434 # but should behave otherwise the same as DiagonalTensor 435 self.assertEqual(torch.mm(t4, t4), 1) 436 self.assertEqual(torch.mm(t4, t1), 1) 437 self.assertEqual(torch.mm(t1, t4), 1) 438 self.assertEqual(torch.mm(t4, t2), 1) 439 self.assertEqual(torch.mm(t2, t4), 1) 440 self.assertEqual(torch.mm(t3, t4), -1) 441 self.assertEqual(torch.mm(t4, t3), 1) 442 443 def test_precedence_semantics(self): 444 """Test semantics for __torch_function__ for functions that take 445 multiple arguments 446 447 For functions that take multiple arguments, the appropriate 448 __torch_function__ implementation to call is determined by 449 examining the types of the arguments. The precedence order is 450 left-to-right in the argument list, except subclasses are always 451 checked before superclasses. 
The first result of calling the 452 implementations in precedence order that is not NotImplemented 453 is returned to the user. If all implementations return 454 NotImplemented, a TypeError is raised. 455 456 All cases are tested with functions implemented in C++ and 457 either foo or baz, which are python functions defined above that 458 are instrumented to obey the same dispatch rules as the 459 functions in torch.functional. 460 """ 461 # DiagonalTensor has a valid override and SubDiagonal has an 462 # override that returns NotImplemented so we should call the 463 # DiagonalTensor implementation, returning -1 464 t1 = DiagonalTensor(5, 2) 465 t2 = SubDiagonalTensor(5, 2) 466 self.assertEqual(torch.div(t1, t2), -1) 467 self.assertEqual(torch.div(t2, t1), -1) 468 self.assertEqual(foo(t1, t2), -1) 469 self.assertEqual(foo(t2, t1), -1) 470 471 # SubTensor has an implementation that returns NotImplemented as 472 # well so it should behave exactly like SubDiagonalTensor in the 473 # test above 474 t3 = SubTensor([[1, 2], [1, 2]]) 475 self.assertEqual(torch.div(t1, t3), -1) 476 self.assertEqual(torch.div(t3, t1), -1) 477 self.assertEqual(foo(t1, t3), -1) 478 self.assertEqual(foo(t3, t1), -1) 479 480 # div between SubTensor and SubDiagonalTensor should raise 481 # TypeError since both have an implementation that 482 # explicitly returns NotImplemented 483 with self.assertRaises(TypeError): 484 torch.div(t2, t3) 485 with self.assertRaises(TypeError): 486 torch.div(t3, t2) 487 with self.assertRaises(TypeError): 488 foo(t2, t3) 489 with self.assertRaises(TypeError): 490 foo(t3, t2) 491 492 # none of DiagonalTensor, SubdiagonalTensor, or SubTensor have a 493 # mul or a baz implementation so all ops should raise TypeError 494 with self.assertRaises(TypeError): 495 torch.mul(t1, t1) 496 with self.assertRaises(TypeError): 497 torch.mul(t1, t2) 498 with self.assertRaises(TypeError): 499 torch.mul(t1, t3) 500 with self.assertRaises(TypeError): 501 torch.mul(t2, t1) 502 with 
self.assertRaises(TypeError): 503 torch.mul(t2, t2) 504 with self.assertRaises(TypeError): 505 torch.mul(t2, t3) 506 with self.assertRaises(TypeError): 507 torch.mul(t3, t1) 508 with self.assertRaises(TypeError): 509 torch.mul(t3, t2) 510 with self.assertRaises(TypeError): 511 torch.mul(t3, t3) 512 with self.assertRaises(TypeError): 513 baz(t1, t1) 514 with self.assertRaises(TypeError): 515 baz(t1, t2) 516 with self.assertRaises(TypeError): 517 baz(t1, t3) 518 with self.assertRaises(TypeError): 519 baz(t2, t1) 520 with self.assertRaises(TypeError): 521 baz(t2, t2) 522 with self.assertRaises(TypeError): 523 baz(t2, t3) 524 with self.assertRaises(TypeError): 525 baz(t3, t1) 526 with self.assertRaises(TypeError): 527 baz(t3, t2) 528 with self.assertRaises(TypeError): 529 baz(t3, t3) 530 531 def test_user_implementation_raises(self): 532 """Test that errors raised in user implementations propagate correctly""" 533 t1 = DiagonalTensor(5, 2) 534 t2 = DiagonalTensor(5, 2) 535 with self.assertRaises(ValueError): 536 torch.add(t1, t2) 537 with self.assertRaises(ValueError): 538 quux(t1) 539 540 def test_tensor_subclass_propagation(self): 541 """this test exercises the functionality described in 542 docs/source/notes/extending.rst#subclassing-torchtensor""" 543 t1 = torch.tensor([5]) 544 t2 = torch.tensor([6]) 545 546 s1 = SubTensor2([5]) 547 s2 = SubTensor2([6]) 548 549 ss1 = SubSubTensor2([5]) 550 ss2 = SubSubTensor2([6]) 551 552 sn1 = SubTensor3([5]) 553 sn2 = SubTensor3([6]) 554 555 # Check that leaf subclass is kept regardless of order 556 self.assertTrue(isinstance(s1 + t2, SubTensor2)) 557 self.assertTrue(isinstance(t1 + s2, SubTensor2)) 558 self.assertTrue(isinstance(s1 + s2, SubTensor2)) 559 560 # Check indexing subclass is kept 561 self.assertTrue(isinstance(s1[0], SubTensor2)) 562 563 # Check case for subclass of subclass. 
564 self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2)) 565 self.assertTrue(isinstance(ss1 + s2, SubSubTensor2)) 566 self.assertTrue(isinstance(s1 + ss2, SubSubTensor2)) 567 self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2)) 568 self.assertTrue(isinstance(ss1 + t2, SubSubTensor2)) 569 self.assertTrue(isinstance(t1 + ss2, SubSubTensor2)) 570 self.assertTrue(isinstance(ss1[0], SubSubTensor2)) 571 572 # Make sure unrelated class trees are not merged. 573 with self.assertRaises(TypeError): 574 s1 + sn2 575 with self.assertRaises(TypeError): 576 sn1 + s2 577 578 def test_base(self): 579 # https://github.com/szagoruyko/pytorchviz/issues/65 580 class DummyTensor(torch.Tensor): 581 pass 582 583 a = torch.ones(1) 584 c = DummyTensor(a) 585 self.assertTrue(c._is_view()) 586 self.assertTrue(c._base is a) 587 588 def test_grad(self): 589 # Previously, Tensor-like objects that did not subclass from Tensor 590 # did not get wrapped into unary tuples before being passed into 591 # handle_torch_function, in contradiction with how Tensor-likes 592 # were handled 593 # 594 # NB: this asserts that the arguments get normalized into a tuple 595 # before entering the torch function handler; it could go the 596 # other way but beware https://github.com/pytorch/pytorch/issues/76037 597 598 class Dummy: 599 @classmethod 600 def __torch_function__(cls, func, types, args=(), kwargs=None): 601 inputs, outputs = args 602 self.assertEqual(inputs, (x,)) 603 self.assertEqual(outputs, (x,)) 604 return -1 605 606 x = Dummy() 607 self.assertEqual(torch.autograd.grad(x, x), -1) 608 609 def test_pow_rpow(self): 610 class NothingImplemented(torch.Tensor): 611 @classmethod 612 def __torch_function__(cls, func, types, args=(), kwargs=None): 613 return NotImplemented 614 615 class RPowOnly(torch.Tensor): 616 @classmethod 617 def __torch_function__(cls, func, types, args=(), kwargs=None): 618 if func is torch.Tensor.__rpow__: 619 return -1 620 return NotImplemented 621 622 
self.assertEqual(NothingImplemented() ** RPowOnly(), -1) 623 624 625def generate_tensor_like_override_tests(cls): 626 from torch.testing._internal.generated.annotated_fn_args import annotated_args 627 628 def test_generator(func, override): 629 # If func corresponds to a torch.Tensor method or property. 630 if is_tensor_method_or_property(func): 631 # Generate an instance by using SubTensor, 632 def instance_gen(): 633 return SubTensor([5]) 634 else: 635 # Otherwise, TensorLike. 636 def instance_gen(): 637 return TensorLike() 638 639 # FIXME The following code does not support kwonly args without defaults. 640 # The fix is easy, as one just needs to save these args when generating the variable 641 # annotated_args. The problem is that, if one does so, one finds a number 642 # of functions that have problematic signatures in native_functions.yaml. 643 # Fixing these would be BC breaking, so hence this terrible hack 644 # https://github.com/pytorch/pytorch/issues/67008 645 kwargs = {} 646 if hasattr(func, "__name__") and "linalg_solve_triangular" in func.__name__: 647 kwargs = {"upper": True} 648 649 func_args = [] 650 is_method = is_tensor_method_or_property(func) 651 652 def _simple_type_parser(func, arg_name, arg_type): 653 # Guess valid input to aten function based on type of argument 654 if arg_type == "Tensor": 655 return instance_gen() 656 elif arg_type == "TensorList" or arg_type == "ITensorListRef": 657 return [instance_gen(), instance_gen()] 658 elif arg_type == "c10::List<::std::optional<Tensor>>": 659 return [instance_gen(), instance_gen()] 660 elif arg_type == "IntArrayRef" or arg_type == "SymIntArrayRef": 661 size = arg.get("size", 2) 662 if size == 1: 663 return 1 664 else: 665 return [1] * size 666 elif arg_type == "Scalar": 667 return 3.5 668 elif arg_type == "bool": 669 return False 670 elif arg_type == "Dimname": 671 return "" 672 elif arg_type == "DimnameList": 673 return [""] 674 elif arg_type.startswith("int"): 675 return 0 676 elif arg_type in 
{"Stream"}: 677 return torch.Stream() 678 elif arg_type.startswith("float") or arg_type == "double": 679 return 1.0 680 elif arg_type in {"Generator", "MemoryFormat", "TensorOptions"}: 681 return None 682 elif arg_type == "ScalarType": 683 return torch.float32 684 elif arg_type == "c10::string_view": 685 return "" 686 elif arg_type == "SymInt": 687 # TODO: generate actual SymbolicInt 688 return 1 689 else: 690 raise RuntimeError( 691 f"Unsupported argument type {arg_type} for {arg_name} of function {func}" 692 ) 693 694 if func in annotated_args: 695 for arg in annotated_args[func]: 696 # Guess valid input to aten function based on type of argument 697 t = arg["simple_type"] 698 if t.endswith("?"): 699 t = t[:-1] 700 if t == "Tensor" and is_method and arg["name"] == "self": 701 # See "Note: properties and __get__" 702 func = func.__get__(instance_gen()) 703 continue 704 arg_to_add = _simple_type_parser(func, arg["name"], t) 705 if "is_kwarg_only" in arg and arg["is_kwarg_only"] == str(True): 706 kwargs[arg["name"]] = arg_to_add 707 else: 708 func_args.append(arg_to_add) 709 else: 710 args = inspect.getfullargspec(override) 711 try: 712 func_args = inspect.getfullargspec(func) 713 # Remove annotations from argspec 714 func_args = type(func_args)(**{**func_args, 'annotations': None}) 715 if func_args != args: 716 raise RuntimeError(f"Override for {func} doesn't match its argspec.\n" 717 + f"Original: {inspect.signature(func)}\n" 718 + f"Override: {inspect.signature(override)}") 719 except TypeError: 720 pass 721 nargs = len(args.args) 722 if args.defaults is not None: 723 nargs -= len(args.defaults) 724 func_args = [instance_gen() for _ in range(nargs)] 725 if args.varargs is not None: 726 func_args += [instance_gen(), instance_gen()] 727 728 def test(self): 729 ret = func(*func_args, **kwargs) 730 # ret is None for certain protocols, e.g., `__weakref__` and `__setitem__` 731 # This is currently the best check but doesn't work for, for example, 732 # Tensor.__add__ 
because it redirects to Tensor.add. 733 # See note "_triggered wrapper" 734 if not is_method or ret is None: 735 self.assertTrue(WRAPPED_TRIGGERED_IMPLS[func]._triggered) 736 return 737 738 self.assertEqual(ret, -1) 739 740 return test 741 742 for func, override in get_testing_overrides().items(): 743 test_method = test_generator(func, override) 744 if func.__name__ == "__get__": 745 # Note: properties and __get__ 746 # __get__ is part of the descriptor protocol. 747 # https://docs.python.org/3/howto/descriptor.html 748 # This is used for properties of the form 749 # torch.Tensor.<property>, with the method __get__ 750 # In this case we get the property name in two ways: 751 752 # This case for properties defined in C. 753 module = getattr( 754 func.__self__, 755 "__qualname__", 756 None 757 ) 758 759 # This one for properties defined in Python. 760 if module is None: 761 module = "Tensor." + func.__self__.fget.__name__ 762 763 # Unfortunately I couldn't find a way to unify these two cases 764 # and there is no way for general descriptors. 
765 elif is_tensor_method_or_property(func): 766 module = "Tensor" 767 else: 768 module = func.__module__ 769 if module: 770 name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__) 771 else: 772 name = f'test_{func.__name__}' 773 test_method.__name__ = name 774 setattr(cls, name, test_method) 775 776generate_tensor_like_override_tests(TestTorchFunctionOverride) 777 778class Wrapper: 779 "Basic data container that knows how to unwrap itself" 780 def __init__(self, data): 781 self.__dict__["_data"] = data 782 self.__dict__["used_attrs"] = set() 783 self.__dict__["used_calls"] = set() 784 785 def __getattr__(self, name): 786 if name in self.__dict__: 787 return self.__dict__[name] 788 self.used_attrs.add(name) 789 790 val = getattr(self._data, name) 791 792 # If it's a method 793 if not isinstance(val, torch.device) and callable(val): 794 c = getattr(type(self._data), name) 795 # Don't append self to args if classmethod/staticmethod 796 if c is val: 797 return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=a, kwargs=kw)) 798 # Otherwise append self to args 799 return lambda *a, **kw: wrap(self.__torch_function__(c, (Wrapper,), args=(self,) + a, kwargs=kw)) 800 801 return wrap(val) 802 803 def __setattr__(self, name, value): 804 if name in self.__dict__: 805 self.__dict__[name] = value 806 807 self.used_attrs.add(name) 808 setattr(self._data, name, unwrap(value)) 809 810 def __setitem__(self, key, value): 811 self._data[unwrap(key)] = unwrap(value) 812 813 def __getitem__(self, key): 814 return wrap(self._data[unwrap(key)]) 815 816 @classmethod 817 def __torch_function__(cls, func, types, args=(), kwargs=None): 818 if kwargs is None: 819 kwargs = {} 820 # Find an instance of this class in the arguments 821 args_of_this_cls = [] 822 for a in args: 823 if isinstance(a, cls): 824 args_of_this_cls.append(a) 825 elif isinstance(a, collections.abc.Sequence): 826 args_of_this_cls.extend(el for el in a if isinstance(el, cls)) 827 assert 
len(args_of_this_cls) > 0 828 for a in args_of_this_cls: 829 a.used_calls.add(func) 830 args = unwrap(tuple(args)) 831 kwargs = {k: unwrap(v) for k, v in kwargs.items()} 832 833 return wrap(func(*args, **kwargs)) 834 835 def __add__(self, other): 836 return self.__torch_function__(torch.add, (Wrapper,), (self, other)) 837 838 def __mul__(self, other): 839 return self.__torch_function__(torch.mul, (Wrapper,), (self, other)) 840 841 def __sub__(self, other): 842 return self.__torch_function__(torch.sub, (Wrapper,), (self, other)) 843 844 def __truediv__(self, other): 845 return self.__torch_function__(torch.true_divide, (Wrapper,), (self, other)) 846 847 def __floordiv__(self, other): 848 return self.__torch_function__(torch.floor_divide, (Wrapper,), (self, other)) 849 850 def __ge__(self, other): 851 return self.__torch_function__(torch.ge, (Wrapper,), (self, other)) 852 853 def __gt__(self, other): 854 return self.__torch_function__(torch.gt, (Wrapper,), (self, other)) 855 856 def __lt__(self, other): 857 return self.__torch_function__(torch.lt, (Wrapper,), (self, other)) 858 859 def __le__(self, other): 860 return self.__torch_function__(torch.le, (Wrapper,), (self, other)) 861 862 def __eq__(self, other): 863 return self.__torch_function__(torch.eq, (Wrapper,), (self, other)) 864 865 def __ne__(self, other): 866 return self.__torch_function__(torch.ne, (Wrapper,), (self, other)) 867 868 def __bool__(self): 869 return self.__torch_function__(torch.Tensor.__bool__, (Wrapper,), (self,)) 870 871 def __int__(self): 872 return self.__torch_function__(torch.Tensor.__int__, (Wrapper,), (self,)) 873 874 def __len__(self): 875 return len(self._data) 876 877 878# unwrap inputs if necessary 879def unwrap(v): 880 if type(v) in {tuple, list}: 881 return type(v)(unwrap(vi) for vi in v) 882 883 return v._data if isinstance(v, Wrapper) else v 884 885# wrap inputs if necessary 886def wrap(v): 887 if type(v) in {tuple, list}: 888 return type(v)(wrap(vi) for vi in v) 889 890 return 
Wrapper(v) if isinstance(v, torch.Tensor) else v


class TestEinsumOverride(TestCase):
    "Regression test for gh-38479"
    def test_wrapper(self):
        x = Wrapper(torch.randn(5))
        y = Wrapper(torch.randn(4))
        self.assertEqual(torch.einsum('i,j->ij', x, y)._data,
                         torch.ger(x, y)._data)

        # in the old einsum interface, `operands` is a list
        a = Wrapper(torch.randn(2, 3))
        b = Wrapper(torch.randn(5, 3, 7))
        c = Wrapper(torch.randn(2, 7))
        self.assertEqual(torch.einsum('ik,jkl,il->ij', [a, b, c])._data,
                         torch.nn.functional.bilinear(a, c, b)._data)


class TestGradCheckOverride(TestCase):
    "Test that wrappers work with gradcheck."
    def test_gradcheck(self):
        from torch.testing._internal.common_utils import gradcheck, gradgradcheck

        def run_test(fast_mode):
            # `wrap` records which attributes/calls gradcheck touches on the
            # Tensor-like (read back via `used_attrs` / `used_calls` below).
            a = wrap(torch.tensor(5.0, dtype=torch.double))
            b = wrap(torch.tensor(6.0, dtype=torch.double))

            a.requires_grad = True
            b.requires_grad = True

            gradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)
            gradgradcheck(torch.add, (a, b), raise_exception=False, check_batched_grad=False, fast_mode=fast_mode)

            total_used_attrs = a.used_attrs.union(b.used_attrs)
            total_used_calls = a.used_calls.union(b.used_calls)

            # These attributes (and the functions below) may change
            # if the gradcheck implementation changes. It's best to
            # aim for attributes that may be commonly present on other
            # Tensor-likes.
            expected_used_attrs = {
                'data',
                'dtype',
                'is_floating_point',
                'is_sparse',
                'layout',
                'new_zeros',
                'numel',
                'requires_grad',
                'requires_grad_',
                'size',
                'stride',
            }
            if fast_mode:
                expected_used_attrs.add('is_complex')
                expected_used_attrs.add('device')
            self.assertEqual(expected_used_attrs, total_used_attrs)

            expected_used_calls = {
                torch.Tensor.new_zeros,
                torch.Tensor.size,
                torch.Tensor.is_floating_point,
                torch.Tensor.numel,
                torch.Tensor.stride,
                torch.Tensor.requires_grad_,
                torch.autograd.grad,
                torch.add,
            }
            if fast_mode:
                expected_used_calls.add(torch.Tensor.is_complex)
            self.assertEqual(expected_used_calls, total_used_calls)
        run_test(fast_mode=True)
        run_test(fast_mode=False)


class TestNamedTuple(TestCase):
    """ Regression test for gh-47090 """
    def test_max(self):
        # torch.max on a subclass must return the same (named-tuple) result
        # type as on a plain tensor.
        x = torch.tensor([1, 2])
        xs = x.as_subclass(SubTensor2)
        r = torch.max(x, dim=0)
        rs = torch.max(xs, dim=0)
        self.assertEqual(type(r), type(rs))
        self.assertEqual(r, rs)


class TestGradNewOnesOverride(TestCase):
    """ Regression test for gh-47069 """
    def test_newones(self):
        # new_ones on a subclass instance keeps the subclass type.
        t = torch.tensor([1, 2]).as_subclass(SubTensor2)
        n = t.new_ones((1, 2))
        self.assertEqual(type(n), SubTensor2)


class TestPickle(TestCase):
    "Regression test for gh-47051"
    def test_pickle(self):
        # A pickle round-trip preserves both the subclass type and
        # user-set instance attributes.
        t = torch.tensor([1]).as_subclass(SubTensor2)
        t.abcd = "e"
        t2 = pickle.loads(pickle.dumps(t))
        self.assertIs(type(t2), SubTensor2)
        self.assertEqual(t2.abcd, "e")


class TestBroadcastAllOverride(TestCase):
    """ test for gh-37141 """
    def test_broadcast_all(self):
        from torch.distributions.utils import broadcast_all
        a = torch.tensor([1.2, 3.4, 5.6])
        a_w = Wrapper(a)
        b = torch.tensor(5.0)
        b_w = Wrapper(b)
        c = torch.tensor([5.0, 5.0, 5.0])

        # Both inputs wrapped: every output is a Wrapper.
        o_1 = broadcast_all(a_w, b_w)
        self.assertIsInstance(o_1[0], Wrapper)
        self.assertIsInstance(o_1[1], Wrapper)
        self.assertEqual(o_1[0]._data, a)
        self.assertEqual(o_1[1]._data, c)

        # Only one input wrapped: outputs are still all Wrappers.
        o_2 = broadcast_all(a_w, b)
        self.assertIsInstance(o_2[0], Wrapper)
        self.assertIsInstance(o_2[1], Wrapper)
        self.assertEqual(o_2[0]._data, a)
        self.assertEqual(o_2[1]._data, c)


class TestWrapTorchFunction(TestCase):
    def test_wrap_torch_function(self):
        # A function decorated with wrap_torch_function dispatches to the
        # __torch_function__ of the arguments returned by `dispatcher`.
        class A:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs):
                return -1

        def dispatcher(a):
            return (a,)

        @torch.overrides.wrap_torch_function(dispatcher)
        def f(a):
            return a

        self.assertEqual(f(A()), -1)


class TestIndexing(TestCase):
    """ Regression tests for gh-46277 """
    def test_getitem(self):
        class A:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                return -1

        t = torch.tensor([5])
        self.assertEqual(t[A()], -1)
        self.assertEqual(t, torch.tensor([5]))

    def test_getitem_subclass(self):
        class A(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                return -1

        t = torch.tensor([5])
        self.assertEqual(t[A()], -1)
        self.assertEqual(t[5, A()], -1)
        self.assertEqual(t, torch.tensor([5]))

    def test_setitem(self):
        triggered = set()

        class A:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                triggered.add(func)
                return -1

        t = torch.tensor([5])
        t[A()] = 1
        t[5, A()] = 1
        self.assertIn(Tensor.__setitem__, triggered)
        self.assertEqual(t, torch.tensor([5]))

    def test_setitem_val(self):
        triggered = set()

        class A:
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                triggered.add(func)
                return -1

        # The overriding object appears as the assigned value, not the index.
        t = torch.tensor([5])
        t[0] = A()
        self.assertIn(Tensor.__setitem__, triggered)
        self.assertEqual(t, torch.tensor([5]))

    def test_setitem_subclass(self):
        triggered = set()

        class A(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                triggered.add(func)
                return -1

        t = torch.tensor([5])
        t[A()] = 1
        t[5, A()] = 1
        self.assertIn(Tensor.__setitem__, triggered)
        self.assertEqual(t, torch.tensor([5]))


class TestIterator(TestCase):
    # Regression test for gh-54457
    def test_iterator(self):
        # Iterating a subclass yields elements of the same subclass.
        t = torch.tensor([5, 6, 7]).as_subclass(SubTensor2)
        it = iter(t)
        self.assertIs(type(next(it)), SubTensor2)
        self.assertIs(type(next(it)), SubTensor2)
        self.assertIs(type(next(it)), SubTensor2)


class TestRNN(TestCase):
    # Regression test for gh-55868
    def test_rnn(self):
        model = torch.nn.RNN(10, 20, 2)
        input = Wrapper(torch.randn(1, 5, 10))
        model(input)


class TestDisabledTorchFunction(TestCase):
    # Regression test for gh-64687
    def test_parameter_does_not_prevent_dispatch(self):
        # An nn.Parameter argument must not suppress dispatch to another
        # argument's __torch_function__, regardless of argument position.
        class MyTensor:
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                return "called"

        t1 = MyTensor()
        t2 = torch.nn.Parameter(torch.rand(2, 2))
        self.assertEqual(torch.add(t2, t1), "called")

        inp = torch.rand(10, 10)
        self.assertEqual(torch.nn.functional.linear(inp, t1, t2), "called")
        self.assertEqual(torch.nn.functional.linear(inp, t2, t1), "called")


class TestResolveName(TestCase):
    def test_resolve_name(self):
        # Every overridable function's resolved name must evaluate back to
        # the same function object.
        for cs in get_overridable_functions().values():
            for c in cs:
                self.assertEqual(
                    eval(torch.overrides.resolve_name(c)),
                    c,
                    msg=f"{c}, {torch.overrides.resolve_name(c)}"
                )


class TestTorchFunctionWarning(TestCase):
    def test_warn_on_invalid_torch_function_standalone_class(self):
        class StandaloneTorchFunctionClass:
            def __torch_function__(self, *args, **kwargs):
                pass
        a = StandaloneTorchFunctionClass()
        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
            # Function that handles torch_function on the python side
            torch.nn.functional.dropout(a)
        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
            # Function that handles torch_function in C++
            torch.abs(a)

    def test_warn_on_invalid_torch_function_tensor_subclass(self):
        class TensorSubclassTorchFunctionClass(torch.Tensor):
            def __torch_function__(self, *args, **kwargs):
                pass
        b = TensorSubclassTorchFunctionClass()
        with self.assertWarnsRegex(DeprecationWarning, "as a plain method is deprecated"):
            # Function that handles torch_function on the python side
            torch.nn.functional.dropout(b)
        with self.assertWarnsRegex(UserWarning, "as a plain method is deprecated"):
            # Function that handles torch_function in C++
            torch.abs(b)


class TestDisabledUserWarnings(TestCase):
    def test_no_implicit_user_warning_for_deprecated_functions(self):
        self.assertNotWarn(get_ignored_functions)
        self.assertNotWarn(get_testing_overrides)
        self.assertNotWarn(get_overridable_functions)
        self.assertNotWarn(lambda: resolve_name(torch.Tensor.add))
        self.assertNotWarn(lambda: is_tensor_method_or_property(torch.Tensor.add))


@unittest.skipIf(TEST_WITH_CROSSREF, "not run with crossref")
class TestTorchFunctionMode(TestCase):
    def test_basic(self):
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -1
        # NB: factory functions get overridden too!
        x = torch.randn(1)
        with A():
            self.assertEqual(torch.randn(3), -1)
            self.assertEqual(torch.add(x, x), -1)
            self.assertEqual(torch.split(None, [2]), -1)  # python side
            self.assertEqual(bar(x), -1)

    def test_factory_override(self):
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -1

        with A():
            self.assertEqual(torch.tensor([1]), -1)
            self.assertEqual(torch.sparse_coo_tensor(1, 1, 1), -1)
            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1), -1)
            self.assertEqual(torch.sparse_coo_tensor(1, 1, (1, 1), check_invariants=False), -1)
            self.assertEqual(torch.sparse_csr_tensor(1, 1, 1, (1, 1), check_invariants=False), -1)
            self.assertEqual(torch.as_tensor([1]), -1)

    def test_modes_handle_first(self):
        # An active mode is consulted before a subclass's own override.
        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return -40

        x = SubTensor()
        with A():
            self.assertEqual(torch.neg(x), -40)
            self.assertEqual(torch.mean(x), -40)
            self.assertEqual(torch.mm(x, x), -40)
            self.assertEqual(bar(x), -40)

    def test_modes_return_notimplemented(self):
        # A mode returning NotImplemented defers to the subclass override;
        # functions the subclass does not handle raise a TypeError naming it.
        class MyMode(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                return NotImplemented

        x = SubTensor()
        with MyMode():
            self.assertEqual(torch.mean(x), 0)
            self.assertEqual(torch.mm(x, x), -1)
            self.assertEqual(bar(x), 1)
            # Call torch.max directly: the TypeError under test must come from
            # dispatch, not from a malformed assertion helper call.
            self.assertRaisesRegex(
                TypeError, r'SubTensor',
                lambda: torch.max(x, x))

    def test_with_mode(self):
        class ErrorA(RuntimeError):
            pass

        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                raise ErrorA

        with self.assertRaises(ErrorA):
            with A():
                torch.empty([])

    def test_with_mode_created_separately(self):
        class ErrorA(RuntimeError):
            pass

        class A(TorchFunctionMode):
            def __torch_function__(self, *args, **kwargs):
                raise ErrorA

        x = A()
        with self.assertRaises(ErrorA):
            with x:
                torch.empty([])

    def test_with_nested_modes(self):
        # The innermost mode runs first, then forwards to the outer one.
        out = []

        class A(TorchFunctionMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_function__(self, func, _, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                out.append(self.msg)
                return func(*args, **kwargs)

        with A("layer1"):
            with A("layer2"):
                torch.empty([])

        self.assertEqual(out, ["layer2", "layer1"])

    def test_nested_same_mode(self):
        # Re-entering the same mode instance stacks it twice.
        out = []

        class A(TorchFunctionMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_function__(self, func, _, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                out.append(self.msg)
                return func(*args, **kwargs)

        with A("layer1") as a:
            with a:
                torch.empty([])

        self.assertEqual(out, ["layer1", "layer1"])

    def test_error_using_class_method_on_mode(self):
        class A(TorchFunctionMode):
            @classmethod
            def __torch_function__(cls, func, _, args=(), kwargs=None):
                return func(args, kwargs)

        x = torch.tensor(5.)
        with self.assertRaisesRegex(RuntimeError, "classmethod is not supported, please make it a plain method"):
            with A():
                x + x

    def test_restacking_with_ancestor(self):
        class A(TorchFunctionMode):
            pass

        with A():
            with A() as x:
                pass

        with x:
            pass

    def test_get_cur_mode(self):
        # Only the mode stack is inspected here; the default
        # TorchFunctionMode behavior is sufficient.
        class A(TorchFunctionMode):
            pass

        with A() as mode1:
            self.assertEqual(_get_current_function_mode(), mode1)

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_function_mode(), mode2)

    def test_get_mode_stack(self):
        # The stack is reported outermost-first.
        class A(TorchFunctionMode):
            pass

        self.assertEqual(_get_current_function_mode_stack(), [])

        with A() as mode1:
            self.assertEqual(_get_current_function_mode_stack(), [mode1])

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_function_mode_stack(), [mode1, mode2])

    def test_all_same_mode(self):
        class A(TorchFunctionMode):
            pass

        x = A()
        y = A()
        self.assertTrue(all_same_mode([x, x, x]))
        self.assertFalse(all_same_mode([x, None]))
        self.assertFalse(all_same_mode([x, y]))

    def test_nested_modes_with_python_has_torch_function(self):
        # Python-side dispatch (bar uses has_torch_function /
        # handle_torch_function) visits nested modes innermost-first.
        called = []

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                called.append("A")
                kwargs = {} if kwargs is None else kwargs
                return func(*args, **kwargs)

        class B(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                called.append("B")
                kwargs = {} if kwargs is None else kwargs
                return func(*args, **kwargs)

        x = torch.randn(3, 4)
        with A():
            with B():
                y = bar(x)

        self.assertEqual(y, x)
        self.assertEqual(called, ["B", "A"])

    def test_reentrant_mode_idiom(self):
        log = []

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                log.append(func)
                if func is torch.sub:
                    with self:
                        input, other = args
                        assert not kwargs
                        return torch.add(input, other, alpha=-1)
                return func(*args, **kwargs)

        x = torch.randn(1)
        y = torch.randn(1)
        with A():
            torch.sub(x, y)
        # add hits the torch function again!
        self.assertEqual(log, [torch.sub, torch.add])

    def test_nn_parse_to(self):
        # This failed because the parser thinks the function is called to()
        # but it's actually called _parse_to()

        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        with A():
            torch._C._nn._parse_to('cpu')

        self.assertTrue(called)

    def test_getitem_call(self):
        # Indexing a tensor with a tensor index (Tensor.__getitem__) must be
        # routed through the active torch function mode.

        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        a = torch.zeros(5)
        b = torch.tensor(0)
        with A():
            a[b]

        self.assertTrue(called)

    def test_distributions_bernoulli(self):
        # This failed because improper use of has_torch_function when
        # is_tensor_like should have been used instead, inside the
        # broadcasting logic called by distributions (Bernoulli doesn't
        # matter per se)

        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        with A():
            torch.distributions.Bernoulli(0.3)

        self.assertTrue(called)

    def test_mode_notimplemented_loop(self):
        # Default tensor subclass implementation disables torch function;
        # when we redispatch to mode we must not treat the objects as
        # eligible

        called = 0

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called += 1
                # The first time we call, the mode sees an active type that
                # it doesn't know how to deal with. The second time, we're
                # instructed to treat it "as if it were a tensor", and so
                # we keep going. I'm not entirely clear if the subclasses
                # disappearing from types is the correct way to do it.
                if any(t is not torch.Tensor for t in types):
                    return NotImplemented
                else:
                    return func(*args, **kwargs)

        class B(torch.Tensor):
            pass

        b = B()

        with A():
            r = torch.neg(b)

        self.assertIs(type(r), B)
        self.assertEqual(called, 2)

        called = 0

        with A():
            r = bar(b)

        self.assertIs(type(r), B)
        self.assertEqual(called, 2)

    def test_disable_subclass_not_mode(self):
        # DisableTorchFunctionSubclass disables subclass overrides but still
        # lets the active mode run.
        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        class B(torch.Tensor):
            pass

        x = B(torch.randn(5))
        with A():
            with torch._C.DisableTorchFunctionSubclass():
                self.assertNotIsInstance(torch.sum(x), B)

        self.assertTrue(called)

    def test_disable_subclass_mode(self):
        # DisableTorchFunction disables both subclass overrides and modes.
        called = False

        class A(TorchFunctionMode):
            def __torch_function__(self, func, types, args=(), kwargs=None):
                nonlocal called
                if kwargs is None:
                    kwargs = {}
                called = True
                return func(*args, **kwargs)

        class B(torch.Tensor):
            pass

        x = B(torch.randn(5))
        with A():
            with torch._C.DisableTorchFunction():
                self.assertNotIsInstance(torch.sum(x), B)

        self.assertFalse(called)

    def test_disable_enable_subclass(self):
        # Holding an _EnableTorchFunction guard inside a
        # DisableTorchFunctionSubclass region restores subclass propagation.
        class A(torch.Tensor):
            pass

        x = A(torch.randn(5))
        with torch._C.DisableTorchFunctionSubclass():
            g = torch._C._EnableTorchFunction()
            try:
                self.assertIsInstance(torch.sum(x), A)
            finally:
                # Drop the guard explicitly so it is released before the
                # enclosing context manager exits.
                del g

    def test_torch_function_all_disabled_api(self):
        # Only DisableTorchFunction (not DisableTorchFunctionSubclass)
        # reports torch function as fully disabled.
        from torch._C import _is_torch_function_all_disabled

        state = _is_torch_function_all_disabled()
        self.assertFalse(state)

        with torch._C.DisableTorchFunction():
            state = _is_torch_function_all_disabled()
            self.assertTrue(state)

        state = _is_torch_function_all_disabled()
        self.assertFalse(state)

        with torch._C.DisableTorchFunctionSubclass():
            state = _is_torch_function_all_disabled()
            self.assertFalse(state)

    def test_subclass_hash(self):
        class DiagTensor(torch.Tensor):
            def __init__(self, diag):
                self._diag = diag

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}

                def get_full_matrices(t):
                    if isinstance(t, DiagTensor):
                        return torch.diag_embed(t._diag)
                    else:
                        return t

                return func(*tree_map(get_full_matrices, args), **tree_map(get_full_matrices, kwargs))

        d = torch.rand(2)
        a = DiagTensor(d)

        self.assertEqual((a + 1), torch.diag_embed(d) + 1)

        # If the hash function was returning the same value, this would
        # fail inside `Tensor.__eq__`.
        # If __hash__ was going through torch_function, the implementation above would
        # be wrong as it would compute the hash on a temporary Tensor thus not ensuring
        # the uniqueness of the hash that we rely on for Tensors.
        s = set()
        s.add(a)
        s.add(DiagTensor(d))

    def test_custom_device_type(self):
        # A mode can intercept torch.device construction and rewrite
        # integer device indices, given positionally or by keyword.
        class CustomDeviceContext(TorchFunctionMode):

            def __torch_function__(self, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}
                if func == torch.device:
                    if args and isinstance(args[0], int):
                        args = ("xla", args[0])
                    elif isinstance(kwargs.get('device'), int):
                        kwargs['device'] = f"xla:{kwargs.get('device')}"
                return func(*args, **kwargs)

        with CustomDeviceContext():
            d_args = torch.device(0)
            self.assertEqual(d_args.type, "xla")
            self.assertEqual(d_args.index, 0)
            d_kwargs = torch.device(device=0)
            self.assertEqual(d_kwargs.type, "xla")
            self.assertEqual(d_kwargs.index, 0)

    def test_device_context_semantics(self):
        # Changing the default device inside another mode updates the
        # DeviceContext at the bottom of the function-mode stack, and the
        # change persists after the outer mode exits.
        from torch._C import _len_torch_function_stack
        from torch.utils._device import DeviceContext
        try:
            torch.set_default_device("cuda")

            def get_stack():
                return [torch._C._get_function_stack_at(i) for i in range(_len_torch_function_stack())]

            base_mode = BaseTorchFunctionMode()
            with base_mode:
                torch.set_default_device("cpu")
                x = torch.ones(2, 2)
                stack = get_stack()
                self.assertIsInstance(stack[0], DeviceContext)
                self.assertEqual(stack[0].device, torch.device("cpu"))

            stack = get_stack()
            self.assertIsInstance(stack[0], DeviceContext)
            self.assertEqual(stack[0].device, torch.device("cpu"))
        finally:
            # Always restore the global default so later tests are unaffected.
            torch.set_default_device(None)


if __name__ == '__main__':
    run_tests()