# Owner(s): ["module: __torch_dispatch__"]

import logging
import sys
import tempfile
import unittest
from copy import deepcopy

import torch
import torch._dynamo
from torch import SymInt
from torch._C import DispatchKey, DispatchKeySet
from torch._custom_op.functional import register_functional_op
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.cuda.jiterator import _create_jit_fn
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.library import _scoped_library, fallthrough_kernel, impl, Library
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    first_sample,
    IS_WINDOWS,
    run_tests,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.logging_tensor import (
    capture_logs,
    capture_logs_with_logging_tensor_mode,
    log_input,
    LoggingTensor,
    LoggingTensorMode,
    LoggingTensorReentrant,
)
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils import _pytree as pytree
from torch.utils._mode_utils import all_same_mode, no_dispatch
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
    is_in_torch_dispatch_mode,
    TorchDispatchMode,
)
from torch.utils._pytree import tree_map, tree_map_only


# used as DataLoader collate_fn below; named here to avoid trying to pickle a lambda
def _identity(x):
    """Return ``x`` unchanged."""
    return x


class TestDispatcherPythonBindings(TestCase):
    """Tests for the low-level dispatcher bindings exposed under ``torch._C``."""

    def test_call_boxed(self) -> None:
        """Calling ``aten::sin`` through ``_dispatch_call_boxed`` matches ``Tensor.sin``."""
        sin = torch._C._dispatch_find_schema_or_throw("aten::sin", "")
        x = torch.randn(3)
        y = torch._C._dispatch_call_boxed(sin, x)
        self.assertEqual(y, x.sin())


class TestPythonRegistration(TestCase):
    """Tests for defining ops and registering kernels/fallbacks from Python via ``torch.library``."""

    # Namespace for operators defined by these tests; removed from torch.ops in tearDown.
    test_ns = "_test_python_registration"

    def tearDown(self):
        """Remove the test namespace attribute from ``torch.ops`` if a test created it."""
        # NOTE(review): the attribute name is hard-coded rather than using
        # delattr(torch.ops, self.test_ns); the two only agree while test_ns
        # keeps its current value.
        if hasattr(torch.ops, self.test_ns):
            del torch.ops._test_python_registration

    def test_fallback(self) -> None:
        """A Python fallback on a dispatch key intercepts factory and regular ops.

        The fallback's return value is what the caller receives.
        """
        test_key = "TESTING_ONLY_GenericMode"
        test_keyset = torch._C.DispatchKeySet(test_key)
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            # my_fallback reads these from the enclosing scope at call time, so
            # reassigning them below retargets each check without re-registering.
            expected_op = None
            expected_args = None
            expected_kwargs = None
            # Use this out shape to make sure the result from our fallback
            # is what is returned to the user
            out_shape = None

            def my_fallback(op, *args, **kwargs):
                # Disable our handler during checks and generating the output
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    self.assertIs(op, expected_op)
                    self.assertEqual(args, expected_args)
                    self.assertEqual(kwargs, expected_kwargs)
                    # Return something specific
                    return torch.empty(out_shape)

            my_lib.fallback(my_fallback, test_key)

            a, b = torch.rand(2), torch.rand(2)

            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                # Check a factory function
                expected_op = torch.ops.aten.empty.memory_format
                expected_args = ((2, 2),)
                # Extra kwargs to bypass issues with default args in factory functions
                expected_kwargs = {
                    "dtype": torch.float64,
                    "pin_memory": False,
                    "device": torch.device("cpu"),
                }
                out_shape = (3,)
                out = torch.empty(*expected_args, **expected_kwargs)
                self.assertEqual(out.size(), out_shape)

                # Check a regular function
                expected_op = torch.ops.aten.add.Tensor
                expected_args = (a, b)
                expected_kwargs = {}
                out_shape = (4,)
                out = a + b
                self.assertEqual(out.size(), out_shape)

    def test_fallback_keyset(self) -> None:
        """A ``with_keyset=True`` fallback can redispatch below its own key.

        A re-entrant op call made from the other fallback reaches it again.

        Expected sequence, as asserted below:
          a + b -> first_fallback (redispatch without its key) -> second_fallback
          -> a - b -> first_fallback again (second_called is now True), which
          computes the op with both test keys excluded.
        """
        test_key_first = "TESTING_ONLY_GenericMode"
        test_key_second = "TESTING_ONLY_GenericWrapper"
        test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
            test_key_second
        )
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            first_called = False
            second_called = False

            def first_fallback(keyset, op, *args, **kwargs):
                nonlocal first_called
                if second_called:
                    # Recursive call
                    first_called = True
                    with torch._C._ForceDispatchKeyGuard(
                        include_to_set, exclude_to_set | test_keyset
                    ):
                        return op(*args, **kwargs)
                else:
                    # Redispatch down
                    keyset = keyset.remove(test_key_first)
                    return op.redispatch(keyset, *args, **kwargs)

            def second_fallback(op, *args, **kwargs):
                nonlocal second_called
                # Set to avoid infinite recursion
                second_called = True
                # New dispatcher call should hit the first callback again
                self.assertFalse(first_called)
                a, b = args
                # Make a subtraction here instead of add !
                c = a - b
                self.assertTrue(first_called)
                return c

            my_lib.fallback(first_fallback, test_key_first, with_keyset=True)
            my_lib.fallback(second_fallback, test_key_second)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                c = a + b

            self.assertEqual(c, a - b)
            self.assertTrue(first_called)
            self.assertTrue(second_called)

    def test_fallback_fallthrough(self) -> None:
        """A ``fallthrough_kernel`` fallback makes dispatch skip its key.

        The fallback registered for the other test key then handles the op.
        """
        test_key_first = "TESTING_ONLY_GenericMode"
        test_key_second = "TESTING_ONLY_GenericWrapper"
        test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
            test_key_second
        )
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            is_called = False

            def my_fallback(op, *args, **kwargs):
                nonlocal is_called
                is_called = True
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    return op(*args, **kwargs)

            my_lib.fallback(torch.library.fallthrough_kernel, test_key_first)
            my_lib.fallback(my_fallback, test_key_second)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                c = a + b

            self.assertEqual(c, a + b)
            self.assertTrue(is_called)

    def test_override_aten_ops_with_multiple_libraries(self) -> None:
        """Overrides from two nested ``aten`` IMPL libraries coexist.

        Each override is removed when its own library goes out of scope.
        """
        x = torch.tensor([1, 2])
        with _scoped_library("aten", "IMPL") as my_lib2:
            with _scoped_library("aten", "IMPL") as my_lib1:
                # Example 1
                def my_neg(*args, **kwargs):
                    return args[0]._neg_view()

                # Now we are secretly making the operator a view op so autograd needs to know how
                # to handle it
                my_lib1.impl("neg", my_neg, "AutogradCPU")

                self.assertTrue(torch.neg(x).is_neg())

                # RuntimeError: impl("aten::neg", ...):
                # Explicitly provided namespace (aten) in operator name does not match ...
                with self.assertRaisesRegex(
                    RuntimeError, "operator name does not match namespace"
                ):
                    with _scoped_library("foo", "DEF") as my_lib3:
                        my_lib3.define("neg(Tensor self) -> Tensor")
                        my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")

                # Example 2
                def my_mul(*args, **kwargs):
                    return torch.zeros_like(args[0])

                # torch.ops.aten.mul.Tensor
                my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")

                y = torch._efficientzerotensor(2)
                self.assertFalse(torch.mul(x, y)._is_zerotensor())

                # Assert that a user can't override the behavior of a (ns, op, dispatch_key)
                # combination if someone has already overridden the behavior for that
                # same combination before them
                with self.assertRaisesRegex(
                    RuntimeError, "already a kernel registered from python"
                ):
                    my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")

            # Validate that lib2 is not affected by removing lib1
            self.assertFalse(torch.mul(x, y)._is_zerotensor())

        # Validate that the old behavior is restored for neg and mul
        self.assertFalse(torch.neg(x).is_neg())
        self.assertTrue(torch.mul(x, y)._is_zerotensor())

    def test_error_if_fn_not_callable(self):
        """``Library.impl`` rejects a non-callable kernel with a TypeError."""
        with self.assertRaisesRegex(
            TypeError, "Input function is required to be a callable"
        ):
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")

    def test_finalizer(self):
        """Deleting a ``Library`` runs its finalizer, which drops its registered impls.

        The checks rely on exact ``sys.getrefcount`` values, so avoid taking
        extra references to ``lib`` or its attributes in this test.
        """
        impls_refcnt = sys.getrefcount(torch.library._impls)
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        lib.define("foo123(Tensor x) -> Tensor")

        # 1 for `lib`, 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib), 2)
        # We gained an additional reference that gets cleared when the finalizer runs
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1)
        # 1 for `lib`
        # 1 for the finalizer
        # 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib._op_impls), 3)

        def foo123(x):
            pass

        lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
        key = f"{self.test_ns}/foo123/CPU"
        self.assertTrue(key in torch.library._impls)

        saved_op_impls = lib._op_impls

        # del will definitely work if the following passes
        self.assertEqual(sys.getrefcount(lib), 2)
        del lib

        # 1 for saved_op_impls
        # 1 for sys.getrefcount
        # This function should be the last user of lib._op_impls:
        # - lib should not have a reference anymore (it was del'ed)
        # - lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(saved_op_impls), 2)

        self.assertTrue(key not in torch.library._impls)

        # lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt)

    def test_override_cpu_sum(self) -> None:
        """A Python CPU kernel for ``aten::sum`` replaces the builtin one.

        The builtin kernel is restored once the library is closed.
        """
        # Example 1
        run = [False]

        def my_sum(*args, **kwargs):
            run[0] = True
            return args[0].clone()

        with _scoped_library("aten", "IMPL") as my_lib1:
            my_lib1.impl("aten::sum", my_sum, "CPU")
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)
            self.assertTrue(run[0])
        # Validate that the old behavior is restored for sum
        self.assertEqual(torch.sum(x), torch.tensor(3))

    def test_override_cuda_with_jiterator(self) -> None:
        """Jiterator-generated kernels can override aten CUDA kernels.

        The examples only run when CUDA is available and the build is not ROCm.
        Otherwise the test does nothing.
        """

        def override_where_cuda() -> None:
            # Example 1: Invert the behavior of where's condition input
            not_where_code_string = """
            template <typename T> T inverted_where(bool cond, T a, T b){
                return !cond ? a : b;
            }
            """
            jitted_where = _create_jit_fn(not_where_code_string)

            CALLED = [False]

            def inverted_where(*args, **kwargs):
                CALLED[0] = True
                return jitted_where(*args, **kwargs)

            # overriding where's cuda kernel with Jiterator generated kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::where.self", inverted_where, "CUDA")

                device = "cuda"
                cond = torch.tensor(
                    [True, True, False], device=device, dtype=torch.bool
                )
                x = torch.tensor([1, 2, 3], device=device)
                y = torch.tensor([-1, -2, -3], device=device)

                self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3]))

        def override_gelu_cuda() -> None:
            # Example 2: Use relu to approximate gelu for faster compute
            fastest_gelu_code_string = """
            template <typename T> T fast_gelu(T a){
                return a > 0 ? a : 0;
            }
            """
            jitted_gelu = _create_jit_fn(fastest_gelu_code_string)

            CALLED = [False]

            def fast_gelu(*args, **kwargs):
                CALLED[0] = True
                return jitted_gelu(*args, **kwargs)

            # overriding gelu's cuda kernel with Jiterator generated relu kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::gelu", fast_gelu, "CUDA")

                x = torch.rand([3, 3], device="cuda", dtype=torch.float)
                self.assertEqual(
                    torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
                )
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertNotEqual(
                torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
            )

        def override_exp_cuda() -> None:
            # Example 3: Preventing exp from exploding for float16
            clipped_exp_code_string = """
            template <typename T> T clipped_exp(T a){
                return a > T(10.0) ? T(22026.4657948) : exp(a);
            }
            """
            jitted_exp = _create_jit_fn(clipped_exp_code_string)

            CALLED = [False]

            def clipped_exp(*args, **kwargs):
                CALLED[0] = True
                return jitted_exp(*args, **kwargs)

            # overriding exp's cuda kernel with clipped_exp kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::exp", clipped_exp, "CUDA")

                x = torch.tensor([0.0, 100.0], device="cuda", dtype=torch.float16)
                self.assertEqual(
                    torch.exp(x),
                    torch.tensor([1.0, 22026.4657948], dtype=torch.float16),
                )
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(
                torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16)
            )

        def override_add_cuda() -> None:
            # Example 4: simulate a hardware bug, where the adder is always off by 1
            buggy_add_code_string = """
            template <typename T> T buggy_add(T a, T b){
                return a + b + T(1);
            }
            """
            jitted_add = _create_jit_fn(buggy_add_code_string)

            CALLED = [False]

            def buggy_add(*args, **kwargs):
                CALLED[0] = True
                return jitted_add(*args, **kwargs)

            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::add.Tensor", buggy_add, "CUDA")

                x_cpu = torch.rand([3, 3], device="cpu")
                y_cpu = torch.rand([3], device="cpu")

                x_cuda = x_cpu.cuda()
                y_cuda = y_cpu.cuda()

                self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu)

        if torch.cuda.is_available() and not TEST_WITH_ROCM:
            override_where_cuda()
            override_gelu_cuda()
            override_exp_cuda()
            override_add_cuda()

    def test_extend_library_with_dispatch_key_arg(self):
        """A library created with ``dispatch_key="CPU"`` uses that key by default.

        It also rejects an explicitly conflicting key.
        """

        def my_sum(*args, **kwargs):
            return args[0].clone()

        with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1:
            # RuntimeError: Explicitly provided dispatch key (Conjugate) is
            # inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
            with self.assertRaisesRegex(
                RuntimeError, "inconsistent with the dispatch key"
            ):
                my_lib1.impl("sum", my_sum, "Conjugate")
            my_lib1.impl("aten::sum", my_sum)
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)

    def test_create_new_library(self) -> None:
        """An op defined in a new DEF library can get CPU and ZeroTensor kernels.

        The two kernels are registered from separate libraries.
        """
        with _scoped_library(self.test_ns, "DEF") as my_lib1:
            my_lib1.define("sum(Tensor self) -> Tensor")

            # Example 1
            @torch.library.impl(my_lib1, "sum", "CPU")
            def my_sum(*args, **kwargs):
                return args[0].clone()

            x = torch.tensor([1, 2])
            op = getattr(torch.ops, self.test_ns).sum
            self.assertEqual(op(x), x)

            with _scoped_library(self.test_ns, "IMPL") as my_lib2:
                # Example 2
                @torch.library.impl(my_lib2, op.default, "ZeroTensor")
                def my_sum_zt(*args, **kwargs):
                    if args[0]._is_zerotensor():
                        return torch._efficientzerotensor(args[0].shape)
                    else:
                        return args[0].clone()

                y = torch._efficientzerotensor(3)
                self.assertTrue(op(y)._is_zerotensor())
                self.assertEqual(op(x), x)

    def test_create_new_library_fragment_no_existing(self):
        """A FRAGMENT library can define ops in a namespace that has no DEF library."""
        with _scoped_library(self.test_ns, "FRAGMENT") as my_lib:
            my_lib.define("sum2(Tensor self) -> Tensor")

            @torch.library.impl(my_lib, "sum2", "CPU")
            def my_sum(*args, **kwargs):
                return args[0]

            x = torch.tensor([1, 2])
            self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)

    def test_create_new_library_fragment_with_existing(self):
        """Several FRAGMENT libraries can add ops to a namespace that already has a DEF library."""
        with _scoped_library(self.test_ns, "DEF") as my_lib1:
            # Create a fragment
            with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2:
                my_lib2.define("sum4(Tensor self) -> Tensor")

                @torch.library.impl(my_lib2, "sum4", "CPU")
                def my_sum4(*args, **kwargs):
                    return args[0]

                x = torch.tensor([1, 2])
                self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)

                # Create another fragment
                with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3:
                    my_lib3.define("sum3(Tensor self) -> Tensor")

                    @torch.library.impl(my_lib3, "sum3", "CPU")
                    def my_sum3(*args, **kwargs):
                        return args[0]

                    x = torch.tensor([1, 2])
                    self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)

    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
    def test_alias_analysis(self):
        """``alias_analysis`` decides whether a scripted call to a no-return op stays in the graph.

        With the default ("" / FROM_SCHEMA) the op is expected to be absent from
        the TorchScript graph, so the inner assert fails. With "CONSERVATIVE" it
        is present. Presumably the JIT treats the FROM_SCHEMA op as side-effect
        free and drops it; this is not verified here.
        """

        def test_helper(alias_analysis=""):
            # NOTE(review): this Library is not scoped; it is only torn down
            # when the local goes away (its finalizer), not by a context manager.
            my_lib1 = Library(self.test_ns, "DEF")  # noqa: TOR901

            called = [0]

            @torch.library.define(
                my_lib1, "_op() -> None", alias_analysis=alias_analysis
            )
            def _op(*args, **kwargs):
                called[0] += 1

            @torch.jit.script
            def _test():
                torch.ops._test_python_registration._op()

            assert "_test_python_registration::_op" in str(_test.graph)

        with self.assertRaises(AssertionError):
            test_helper("")  # alias_analysis="FROM_SCHEMA"

        test_helper("CONSERVATIVE")

    def test_error_for_unsupported_ns_or_kind(self) -> None:
        """``Library`` rejects unknown kinds and the reserved ``prim`` namespace."""
        with self.assertRaisesRegex(ValueError, "Unsupported kind"):
            my_lib1 = Library("myns", "BLA")  # noqa: TOR901

        for kind in ("DEF", "FRAGMENT"):
            with self.assertRaisesRegex(ValueError, "reserved namespace"):
                my_lib1 = Library("prim", kind)  # noqa: TOR901

    def test_returning_symint(self) -> None:
        """A CompositeExplicitAutograd kernel can take and return SymInts from a ShapeEnv."""
        shape_env = ShapeEnv()
        fake_tensor_mode = FakeTensorMode(shape_env=shape_env)

        ft = fake_tensor_mode.from_tensor(torch.rand(2, 3))

        # Symbolic sizes backed by the concrete values 2 and 3.
        s0, s1 = ft.shape

        with _scoped_library(self.test_ns, "DEF") as tlib:
            tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")

            @impl(tlib, "sqsum", "CompositeExplicitAutograd")
            def sqsum(a: SymInt, b: SymInt):
                return a * a + b * b

            out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
            out_val = shape_env.evaluate_expr(out.node.expr)
        # 2 * 2 + 3 * 3
        self.assertEqual(out_val, 13)

    def test_register_functional_op_error_cases(self):
        """``register_functional_op`` rejects bad inputs.

        Rejected inputs are non-OpOverload values, ops that are not mutable,
        and schemas it does not support yet ("NYI").
        """
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
                register_functional_op(lib, "abs", torch.ops.aten.abs_)
            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
                register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
                register_functional_op(lib, "abs", torch.ops.aten.abs.out)

            schemas = [
                "foo(Tensor x, Tensor(a!)[] y) -> ()",
                "foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)",
                "foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))",
            ]

        for schema in schemas:
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define(schema)
                with self.assertRaisesRegex(RuntimeError, "NYI"):
                    register_functional_op(
                        lib,
                        "foo_functional",
                        getattr(torch.ops, self.test_ns).foo.default,
                    )

    def _check_is_functional_variant(self, mutable_op, functional_op, args):
        """Assert that ``functional_op`` behaves as a functional variant of ``mutable_op``.

        For the given ``args`` this checks that:
        - calling ``functional_op`` leaves its (cloned) inputs unchanged;
        - the flattened outputs of ``functional_op`` begin with the flattened
          outputs of ``mutable_op``, followed by the args that ``mutable_op``
          changed in place;
        - functionalizing a call to ``mutable_op`` and tracing it with
          ``make_fx`` yields a graph that calls ``functional_op`` and never
          ``mutable_op``.

        Args:
            mutable_op: the mutating OpOverload.
            functional_op: the OpOverload registered as its functional variant.
            args: positional arguments for both ops; tensors are cloned before
                each call, so the caller's tensors are not mutated.
        """
        # functional op should not mutate
        cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
        functional_result = functional_op(*cloned_args)
        self.assertEqual(cloned_args, args)

        # check functional_result includes mutable_result
        # (this call mutates cloned_args in place)
        mutable_result = mutable_op(*cloned_args)
        if mutable_result is None:
            flat_mutable_result = []
        else:
            flat_mutable_result = pytree.tree_leaves(mutable_result)
        flat_functional_result = pytree.tree_leaves(functional_result)
        # NOTE(review): bare assert is stripped under `python -O`;
        # self.assertGreater would keep this check in optimized runs.
        assert len(flat_functional_result) > len(flat_mutable_result)
        self.assertEqual(
            flat_functional_result[: len(flat_mutable_result)], flat_mutable_result
        )

        # check rest of functional_result is the mutated args
        # An arg is kept unless both values are non-None and allclose; None args
        # (e.g. an omitted optional tensor) are therefore kept as well.
        mutated_args = [
            maybe_mutated_arg
            for maybe_mutated_arg, arg in zip(cloned_args, args)
            if not (
                maybe_mutated_arg is not None
                and arg is not None
                and torch.allclose(maybe_mutated_arg, arg)
            )
        ]
        self.assertEqual(
            flat_functional_result[len(flat_mutable_result) :], mutated_args
        )

        # check that functionalization kernel was indeed registered
        def fn(*args):
            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
            mutable_op(*cloned_args)
            return cloned_args

        gm = make_fx(torch.func.functionalize(fn))(*args)
        has_functional_op = False
        for node in gm.graph.nodes:
            self.assertFalse(node.target is mutable_op)
            if node.target is functional_op:
                has_functional_op = True
        self.assertTrue(has_functional_op)

    def test_register_functional_op_no_returns(self):
        """Functional variant of an op with two mutated args and no returns."""
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define("foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()")

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_functional_op_with_optional(self):
        """Functional variant of an op whose last mutated arg is an optional tensor.

        It is exercised both with that tensor present and with ``None``.
        """
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                z.fill_(2.71)
                if w is not None:
                    w.fill_(1.618)

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, None),
            )

    def test_register_functional_op_one_return(self):
        """Functional variant of an op with three mutated args and a single Tensor return."""
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)
                z.fill_(0.99)
                return x.clone()

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_functional_op_multiple_returns(self):
        """Functional variant of an op with two mutated args and a tuple of two Tensor returns."""
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)
                return x.clone(), z.clone()

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )

            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_fallthrough(self):
        """A fallthrough kernel for ``mm`` on AutocastCPU turns off autocast for ``mm`` only."""
        with _scoped_library("aten", "IMPL") as my_lib:
            my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")

            a = torch.randn(2, 3, device="cpu", dtype=torch.float32)
            b = torch.randn(3, 2, device="cpu", dtype=torch.float32)
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                # dtype for mm should be float32 since we registered a fallthrough
                self.assertEqual(torch.mm(a, b).dtype, torch.float32)
                # ops that don't have a fallthrough registered should not be affected
                self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            # default behavior should have been restored
            self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)


class TestPythonDispatch(TestCase):
    """Tests for ``__torch_dispatch__`` tensor subclasses, largely via LoggingTensor call logs."""

    def test_basic(self) -> None:
        """Forward and backward of ``x * x`` on a LoggingTensor log the expected aten calls."""
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            y = x * x
            saved_x = y.grad_fn._saved_self
            grad_y = LoggingTensor(torch.tensor([1.0]))
            log_input("grad_y", grad_y)
            (g,) = torch.autograd.grad((y,), (x,), (grad_y,))

        self.assertEqual(g.elem, torch.tensor([6.0]))
        with torch.no_grad():
            self.assertEqual(saved_x, x)
            self.assertEqual(saved_x._version, x._version)
            x.add_(2)
            self.assertEqual(saved_x, x)
            # TODO: figure out why broken
            # self.assertEqual(saved_x._version, x._version)
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
$2: f32[1] = input('grad_y')
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)""",
        )

    def test_out(self) -> None:
        """An ``out=`` call is logged as the ``.out`` overload with ``out`` as a kwarg."""
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.zeros(1))
            log_input("x", x)
            log_input("y", y)
            torch.abs(x, out=y)

        self.assertEqual(y.elem, torch.ones(1))
        # TODO: arguably this shouldn't pass and we should complain
        # that out isn't a kwarg
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = input('y')
$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)""",
        )

    def test_kwarg_only(self) -> None:
        """Kwarg-only arguments are logged only when they differ from their defaults."""
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1, 1))
            z = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            log_input("z", z)
            torch.addmv(x, y, z)
            torch.addmv(x, y, z, beta=1)
            torch.addmv(x, y, z, beta=2)
            torch.addmv(x, y, z, alpha=2)
            torch.addmv(x, y, z, beta=2, alpha=2)

        # The expectation is that beta/alpha don't show up when they're
        # defaulted. This is even if the user explicitly specified it.
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1, 1] = input('y')
$2: f32[1] = input('z')
$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)""",
        )

    def test_kwarg_only_and_positional_default(self) -> None:
        """A defaulted positional arg is omitted from the log even when a later kwarg is set."""
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            log_input("x", x)
            torch.ops.aten._foobar(x)
            torch.ops.aten._foobar(x, False)
            torch.ops.aten._foobar(x, arg3=False)
            torch.ops.aten._foobar(x, False, arg3=False)

        # What we are testing here is that we omit arg2
        # if it is defaulted, even if a kwarg is set
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten._foobar.default($0)
$2: f32[1] = torch._ops.aten._foobar.default($0, False)
$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False)
$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""",
        )

    def test_produce_real_type(self) -> None:
        """dtype and memory_format arguments are logged as real torch objects."""
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(2, 2))
            log_input("x", x)
            x.to(dtype=torch.double)  # non-optional dtype
            torch.cumprod(x, 0, dtype=torch.double)  # optional dtype
            x[:, 1].contiguous(
                memory_format=torch.contiguous_format
            )  # optional memory format
            # There doesn't appear to be any layout signatures which are
            # triggerable using tensor subclasses (need to use a mode)

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[2, 2] = input('x')
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""",
        )

    def test_optional_tensor_list(self) -> None:
        """A ``Tensor?[]`` argument that contains ``None`` is logged element by element."""

        def weird(xs):
            print("woof")
            return torch.empty(())

        with _scoped_library("my_lib", "DEF") as my_lib:
            my_lib.define("weird(Tensor?[] self) -> Tensor")
            my_lib.impl("weird", weird, "CPU")
            with capture_logs() as logs:
                x = LoggingTensor(torch.ones(2, 2))
                log_input("x", x)
                torch.ops.my_lib.weird.default([None, x])

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[2, 2] = input('x')
$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
        )

    def test_list_ret(self) -> None:
        # test all sequence types are permissible returns
        for list_type in (list, tuple):

            class A(torch.Tensor):
                @staticmethod
                def __new__(cls, elem):
                    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

                @classmethod
                def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    # list_type is read while the class is used in this same
                    # loop iteration, so it refers to the current sequence type.
                    if func.overloadpacket == torch.ops.aten.split:
                        with no_dispatch():
                            return list_type(torch.split(*args))
                    else:
                        raise AssertionError(f"unrecognized func: {func}")

            self.assertEqual(
                torch.split(A(torch.tensor([0, 1])), 2),
                torch.split(torch.tensor([0, 1]), 2),
            )

    def test_invalid_ret(self) -> None:
        # test invalid return gets reasonable error message
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return "arf"

        # The full message wobbles depending on NDEBUG mode of pybind11, so only
        # the stable "Unable to cast" substring is matched.
        self.assertRaisesRegex(
            RuntimeError,
            "Unable to cast",
            lambda: A(torch.zeros(1)).neg(),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Unable to cast",
            lambda: A(torch.zeros(1)).detach(),
        )

    def test_detach_appears_twice_when_called_once(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            x.detach()
        # FIXME: We actually want this to emit a single detach. However,
        # it currently emits two, for reasons unclear to us. Leaving
        # this test here to make sure we don't regress even further (it
        # would be bad if calling .detach() once emits 3+ detaches).
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)""",
        )

    def test_storage(self) -> None:
        # For now, just make sure it doesn't crash. Ideally, we should
        # return some virtual storage that is safe to work with
        x = LoggingTensor(torch.ones(1))
        storage = x.untyped_storage()
        self.assertRaises(RuntimeError, lambda: storage.data_ptr())

    def test_make_wrapper_subclass_noalloc(self) -> None:
        # This is ludicrously big (8TB) and this should pass because wrapper
        # subclasses don't allocate
        torch.Tensor._make_wrapper_subclass(LoggingTensor, (1000000000000,))

    def test_version(self) -> None:
        """In-place ops through ``.detach()`` bump the version counter; ops through ``.data`` do not."""
        x = LoggingTensor(torch.ones(1))
        prev_vc = x._version
        x.detach().add_(2)
        cur_vc = x._version
        self.assertNotEqual(prev_vc, cur_vc)
        x.data.add_(2)
        self.assertEqual(cur_vc, x._version)

    def test_subclass_priority(self) -> None:
        """A subclass's ``__torch_dispatch__`` wins over its parent's.

        This holds regardless of argument order.
        """

        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        # The big tests for code coverage are test_precedence_semantics in
        # test_overrides.py; this is just to make sure it is wired up at all
        # correctly for __torch_dispatch__
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        self.assertRaises(
            ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1)))
        )

    def test_format(self) -> None:
        """``str``, ``repr`` and f-string formatting of a LoggingTensor agree."""
        x = LoggingTensor(torch.ones(1))
        s1 = str(x)
        s2 = repr(x)
        s3 = f"{x}"
        self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(s1, s2)
        self.assertEqual(s1, s3)

    def test_custom_autograd(self) -> None:
        """A custom ``autograd.Function`` receives LoggingTensors in backward.

        Its forward and backward ops are logged.
        """
        # Lets the test inspect the tensor saved for backward after backward runs.
        escape = [None]

        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                assert isinstance(grad_output, LoggingTensor)
                (x,) = ctx.saved_tensors
                assert isinstance(x, LoggingTensor)
                escape[0] = x
                return grad_output * 2 * x

        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1), requires_grad=True)
            log_input("x", x)
            x.grad = LoggingTensor(torch.zeros(1))
            log_input("x.grad", x.grad)
            y = Square.apply(x)
            grad_output = LoggingTensor(torch.ones(1))
            log_input("grad_output", grad_output)
            y.backward(grad_output)

        with torch.no_grad():
            self.assertEqual(escape[0], x)
            self.assertEqual(escape[0]._version, x._version)
            # TODO: figure out why x.requires_grad = False doesn't
            # trigger an error for LoggingTensor
            x.add_(2)
            self.assertEqual(escape[0], x)
            # TODO: figure out why this is broken
            # self.assertEqual(escape[0]._version, x._version)

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = input('x.grad')
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3: f32[1] = input('grad_output')
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)""",
        )

    def test_subclass_creation(self):
        """Re-wrapping a LoggingTensor as another subclass raises instead of succeeding."""
        # Each conversion below is expected to raise: the tensor is already
        # associated with a LoggingTensor Python object (see err_msg), so it
        # cannot also be turned into a Foo. A plain torch.Tensor subclass
        # without __torch_dispatch__ also cannot be a wrapper subclass.
        class Foo(torch.Tensor):
            pass

        err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
        with self.assertRaisesRegex(RuntimeError, err_msg):
            b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
        with self.assertRaisesRegex(RuntimeError, err_msg):
            Foo(LoggingTensor(torch.rand(2)))

        with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
            torch.Tensor._make_wrapper_subclass(Foo, (2, 2))

    def test_new_ones(self) -> None:
        """``new_ones`` on a subclass returns the type produced by its ``__torch_dispatch__``."""

        class MyTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        self.assertEqual(type(MyTensor(2).new_ones(3)), MyTensor)

    def test_like(self) -> None:
        class MyTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        for f in ["empty", "ones", "rand", "randn", "zeros"]:
            f_name = f + "_like"
self.assertEqual(type(getattr(torch, f_name)(MyTensor(2))), MyTensor) 1123 1124 self.assertEqual(type(torch.full_like(MyTensor(2), 1.0)), MyTensor) 1125 self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor) 1126 1127 def test_make_fx_with_subclass(self) -> None: 1128 def f(x, y): 1129 # Returns (TwoTensor, Tensor) 1130 return x * y, y + y 1131 1132 x_a = torch.zeros(4) 1133 x_b = torch.zeros(4) 1134 y = torch.ones(4) 1135 1136 # make_fx() is not responsible for unwrapping tensor subclass inputs, 1137 # so we do it manually here. 1138 # Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling 1139 # convention as f(*args). Unwrapping tensor subclass inputs can potentially change 1140 # the number of input args to the graph, breaking that assumption 1141 def f_to_trace(x_a, x_b, y): 1142 x = TwoTensor(x_a, x_b) 1143 out1, out2 = f(x, y) 1144 out1_unwrapped_attrs, _ = out1.__tensor_flatten__() 1145 return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2) 1146 1147 fx_g = make_fx(f_to_trace, tracing_mode="fake")(x_a, x_b, y) 1148 self.assertExpectedInline( 1149 fx_g.code, 1150 """\ 1151 1152 1153 1154def forward(self, x_a_1, x_b_1, y_1): 1155 mul = torch.ops.aten.mul.Tensor(x_a_1, y_1); x_a_1 = None 1156 mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1); x_b_1 = None 1157 add = torch.ops.aten.add.Tensor(y_1, y_1); y_1 = None 1158 return (mul, mul_1, add) 1159 """, 1160 ) 1161 1162 # See https://github.com/pytorch/pytorch/issues/117794 1163 def test_return_and_correct_aliasing_gives_correct_stride(self): 1164 t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2)) 1165 x = torch.randn(2, 2) 1166 # slicing should result in the same stride for TwoTensor as a dense tensor would give 1167 self.assertEqual(t[:, 0].stride(), x[:, 0].stride()) 1168 1169 def test_make_wrapper_subclass_propagates_metadata(self) -> None: 1170 class WrapperTensor(torch.Tensor): 1171 elem: torch.Tensor 1172 1173 __slots__ = ["elem"] 1174 
1175 @staticmethod 1176 def __new__(cls, elem, *args, **kwargs): 1177 r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 1178 cls, 1179 elem.size(), 1180 dtype=elem.dtype, 1181 layout=elem.layout, 1182 device=elem.device, 1183 requires_grad=elem.requires_grad, 1184 strides=elem.stride(), 1185 storage_offset=elem.storage_offset(), 1186 ) 1187 r.elem = elem 1188 return r 1189 1190 @classmethod 1191 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1192 raise RuntimeError("NYI") 1193 1194 # non-contiguous strides, non-zero storage offset 1195 x = torch.randn(4, 6).t().diagonal(offset=2) 1196 y = WrapperTensor(x) 1197 self.assertEqual(y.size(), x.size()) 1198 self.assertEqual(y.stride(), x.stride()) 1199 self.assertEqual(y.storage_offset(), x.storage_offset()) 1200 1201 def test_wrapper_subclass_serializes(self) -> None: 1202 with tempfile.TemporaryFile() as f: 1203 # purposefully use int64 to test non-default dtype 1204 x = LoggingTensor(torch.randperm(3)) 1205 torch.save(x, f) 1206 f.seek(0) 1207 with torch.serialization.safe_globals([LoggingTensor]): 1208 x_loaded = torch.load(f) 1209 self.assertTrue(type(x_loaded) is type(x)) 1210 self.assertEqual(x, x_loaded) 1211 self.assertEqual(x.elem, x_loaded.elem) 1212 self.assertFalse(x is x_loaded) 1213 1214 def test_deepcopy_wrapper_subclass(self) -> None: 1215 # purposefully use int64 to test non-default dtype 1216 x = LoggingTensor(torch.randperm(3)) 1217 x_copy = deepcopy(x) 1218 self.assertTrue(type(x_copy) is type(x)) 1219 self.assertEqual(x, x_copy) 1220 self.assertEqual(x.elem, x_copy.elem) 1221 self.assertFalse(x is x_copy) 1222 1223 def test_deepcopy_wrapper_subclass_with_clone_returning_different_type( 1224 self, 1225 ) -> None: 1226 class MyWrapperTensor(torch.Tensor): 1227 elem: torch.Tensor 1228 1229 __slots__ = ["elem"] 1230 1231 @staticmethod 1232 def __new__(cls, elem, *args, **kwargs): 1233 r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 1234 cls, 
1235 elem.size(), 1236 dtype=elem.dtype, 1237 layout=elem.layout, 1238 device=elem.device, 1239 requires_grad=elem.requires_grad, 1240 strides=elem.stride(), 1241 storage_offset=elem.storage_offset(), 1242 ) 1243 r.elem = elem 1244 return r 1245 1246 @classmethod 1247 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1248 if func.overloadpacket.__name__ == "clone": 1249 # Return a plain tensor from clone(). 1250 return args[0].elem.clone() 1251 raise RuntimeError("NYI") 1252 1253 # NB: The default Tensor.__torch_function__ implementation called for deepcopy 1254 # disables __torch_function__ by the time we get to clone(), so there is no need to 1255 # explicitly disable __torch_function__ for this subclass. 1256 1257 x = MyWrapperTensor(torch.randn(3)) 1258 with self.assertRaisesRegex( 1259 RuntimeError, 1260 "for which cloning returns another instance of the same subclass", 1261 ): 1262 x_copy = deepcopy(x) 1263 1264 def test_deepcopy_non_wrapper_subclass(self) -> None: 1265 # Ensure correct error is thrown for common error cases. 1266 class SubTensorError1(torch.Tensor): 1267 # Default implementation of new_empty() returns a plain tensor. 1268 pass 1269 1270 class SubTensorError2(torch.Tensor): 1271 # new_empty() incorrectly returns a different type (i.e. a plain tensor). 1272 def new_empty(self, shape): 1273 return torch.Tensor(shape) 1274 1275 for error_cls in [SubTensorError1, SubTensorError2]: 1276 x = error_cls(3) 1277 with self.assertRaisesRegex( 1278 RuntimeError, 1279 "for which that function returns another instance of the same subclass", 1280 ): 1281 x_copy = deepcopy(x) 1282 1283 # Ensure a correctly implemented new_empty() causes deepcopy() to work. 
1284 class SubTensorSuccess(torch.Tensor): 1285 def new_empty(self, shape): 1286 return type(self)(shape) 1287 1288 x = SubTensorSuccess(3) 1289 x_copy = deepcopy(x) 1290 self.assertIs(type(x_copy), type(x)) 1291 1292 def test_wrapper_subclass_extra_dispatch_keys(self) -> None: 1293 class ExtraKeysTensor(torch.Tensor): 1294 @staticmethod 1295 def __new__(cls, elem, *args, **kwargs): 1296 # NB: only the non-kwarg overload of _make_wrapper_subclass supports 1297 # extra dispatch keys. We probably want to unify the two APIs 1298 # in the future. 1299 r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 1300 cls, 1301 elem.size(), 1302 elem.stride(), 1303 elem.storage_offset(), 1304 torch.contiguous_format, 1305 elem.dtype, 1306 elem.layout, 1307 elem.device, 1308 False, 1309 False, 1310 None, 1311 False, 1312 False, 1313 DispatchKeySet(DispatchKey.NestedTensor), 1314 ) 1315 return r 1316 1317 @classmethod 1318 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1319 pass 1320 1321 x = ExtraKeysTensor(torch.randn(3)) 1322 self.assertTrue(torch._C._dispatch_keys(x).has(DispatchKey.NestedTensor)) 1323 self.assertFalse( 1324 torch._C._dispatch_keys(x).has(DispatchKey.AutogradNestedTensor) 1325 ) 1326 1327 def test_wrapper_subclass_multiprocessing_preserves_dtype(self): 1328 # a and b have dtype of int64, which is purposefully different from the default 1329 # assumed by _make_wrapper_subclass(). 
1330 a = torch.randperm(5) 1331 b = torch.randperm(5) 1332 data = TwoTensor(a, b) 1333 expected_dtype = data.dtype 1334 1335 loader = torch.utils.data.DataLoader( 1336 [data, data], 1337 batch_size=2, 1338 num_workers=2, 1339 collate_fn=_identity, 1340 ) 1341 for batch in loader: 1342 self.assertEqual(batch[0].dtype, expected_dtype) 1343 1344 def test_index_put_where_only_index_is_subclass(self) -> None: 1345 called_funcs = [] 1346 1347 class MyTensor(torch.Tensor): 1348 elem: torch.Tensor 1349 __slots__ = ["elem"] 1350 1351 @staticmethod 1352 def __new__(cls, elem, *args, **kwargs): 1353 r = torch.Tensor._make_wrapper_subclass( 1354 cls, 1355 elem.size(), 1356 dtype=elem.dtype, 1357 layout=elem.layout, 1358 device=elem.device, 1359 requires_grad=elem.requires_grad, 1360 ) 1361 r.elem = elem 1362 return r 1363 1364 @classmethod 1365 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1366 called_funcs.append(func) 1367 return MyTensor(torch.tensor(3)) 1368 1369 x = torch.randn(3, 3) 1370 idxs = (MyTensor(torch.tensor(0)),) 1371 v = torch.randn(1) 1372 res = x.index_put_(idxs, v) 1373 self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default]) 1374 1375 def test_torch_dispatch_mode_basic(self) -> None: 1376 with capture_logs(is_mode=True) as logs: 1377 with LoggingTensorMode(): 1378 torch.empty([]) 1379 self.assertExpectedInline( 1380 "\n".join(logs), 1381 """\ 1382$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""", 1383 ) 1384 1385 def test_torch_dispatch_mode_unrelated_tensors(self) -> None: 1386 x = torch.randn([]) 1387 y = torch.randn([]) 1388 with capture_logs(is_mode=True) as logs: 1389 with LoggingTensorMode(): 1390 x + y 1391 self.assertExpectedInline( 1392 "\n".join(logs), """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)""" 1393 ) 1394 1395 def test_nested_push_logging_tensor_mode(self): 1396 x = torch.randn([]) 1397 y = torch.randn([]) 1398 with capture_logs(is_mode=True) as logs: 1399 
with LoggingTensorMode(): 1400 with LoggingTensorMode(): 1401 torch.empty([]) 1402 x + y 1403 1404 self.assertExpectedInline( 1405 "\n".join(logs), 1406 """\ 1407$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1408$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1409$3: f32[] = torch._ops.aten.add.Tensor($1, $2) 1410$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""", 1411 ) 1412 1413 def test_capture_logs_with_torch_dispatch_mode(self): 1414 x = torch.randn([]) 1415 y = torch.randn([]) 1416 with capture_logs_with_logging_tensor_mode() as logs: 1417 torch.empty([]) 1418 x + y 1419 self.assertExpectedInline( 1420 "\n".join(logs), 1421 """\ 1422$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1423$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""", 1424 ) 1425 1426 x = torch.randn([]) 1427 y = torch.randn([]) 1428 1429 with capture_logs_with_logging_tensor_mode() as logs1: 1430 with capture_logs_with_logging_tensor_mode() as logs2: 1431 torch.empty([]) 1432 x + y 1433 1434 self.assertExpectedInline( 1435 "\n".join(logs2), 1436 """\ 1437$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1438$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1439$3: f32[] = torch._ops.aten.add.Tensor($1, $2) 1440$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""", 1441 ) 1442 1443 self.assertEqual(logs1, logs2) 1444 1445 def test_torch_dispatch_mode_subclass_priority(self) -> None: 1446 class ErrorA(RuntimeError): 1447 pass 1448 1449 class ErrorB(RuntimeError): 1450 pass 1451 1452 class A(torch.Tensor): 1453 @staticmethod 1454 def __new__(cls, elem): 1455 return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1456 1457 @classmethod 1458 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1459 with AMode(): 1460 raise ErrorA 1461 1462 
class B(A): 1463 @staticmethod 1464 def __new__(cls, elem): 1465 return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1466 1467 @classmethod 1468 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1469 with BMode(): 1470 func(*args, **kwargs) 1471 1472 class AMode(TorchDispatchMode): 1473 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1474 raise ErrorA 1475 1476 class BMode(TorchDispatchMode): 1477 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1478 raise ErrorB 1479 1480 a = A(torch.empty(1)) 1481 b = B(torch.empty(1)) 1482 with self.assertRaises(ErrorA): 1483 a + a 1484 with self.assertRaises(ErrorB): 1485 a + b 1486 1487 # B has precedence over A due to the subclass relationship yet 1488 # modes take precedence over arguments 1489 with self.assertRaises(ErrorA): 1490 with AMode(): 1491 b + b 1492 with self.assertRaises(ErrorB): 1493 with BMode(): 1494 a + a 1495 with self.assertRaises(ErrorB): 1496 with BMode(): 1497 a + b 1498 1499 def test_mode_with_make_subclass(self): 1500 class SubTensor(torch.Tensor): 1501 @staticmethod 1502 def __new__(cls, elem): 1503 return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1504 1505 class BasicMode(TorchDispatchMode): 1506 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1507 return func(*args, **kwargs) 1508 1509 x = torch.randn(3) 1510 with BasicMode(): 1511 y = SubTensor(x) 1512 self.assertIsInstance(y, SubTensor) 1513 1514 def test_torch_dispatch_mode_respects_no_dispatch(self) -> None: 1515 with capture_logs(is_mode=True) as logs1: 1516 with LoggingTensorMode(): 1517 torch.ones([2, 3]) 1518 with no_dispatch(): 1519 torch.ones([2, 3]) 1520 with capture_logs(is_mode=True) as logs2: 1521 with LoggingTensorMode(): 1522 torch.ones([2, 3]) 1523 self.assertEqual(logs1, logs2) 1524 1525 def test_shallow_copy_and_detach(self) -> None: 1526 seen = set() 1527 test_case = self 1528 1529 class TestMode(TorchDispatchMode): 1530 def 
__torch_dispatch__(self, func, types, args=(), kwargs=None): 1531 tree_map_only( 1532 torch.Tensor, lambda t: test_case.assertIn(t, seen), (args, kwargs) 1533 ) 1534 if kwargs is None: 1535 kwargs = {} 1536 r = func(*args, **kwargs) 1537 tree_map_only(torch.Tensor, lambda t: seen.add(t), r) 1538 return r 1539 1540 with TestMode(): 1541 x = torch.randn(3, requires_grad=True) 1542 loss = (x * x).sum() 1543 loss.backward() 1544 1545 def test_exception_handling(self): 1546 class A(torch.Tensor): 1547 @staticmethod 1548 def __new__(cls, elem): 1549 return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1550 1551 class AMode(TorchDispatchMode): 1552 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1553 if func.__name__ == "randn.default": 1554 raise RuntimeError 1555 return A(torch.zeros(())) 1556 1557 with AMode(): 1558 try: 1559 torch.randn(()) 1560 except RuntimeError: 1561 pass 1562 self.assertTrue(isinstance(torch.zeros(()), A)) 1563 1564 def test_with_mode_created_separately(self): 1565 class ErrorA(RuntimeError): 1566 pass 1567 1568 class A(TorchDispatchMode): 1569 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1570 raise ErrorA 1571 1572 x = A() 1573 with self.assertRaises(ErrorA): 1574 with x: 1575 torch.empty([]) 1576 1577 def test_with_nested_modes(self): 1578 class ErrorA(RuntimeError): 1579 def __init__(self, msg): 1580 super().__init__(msg) 1581 1582 class A(TorchDispatchMode): 1583 def __init__(self, msg): 1584 self.msg = msg 1585 1586 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1587 raise ErrorA(self.msg) 1588 1589 with self.assertRaisesRegex(ErrorA, "layer2"): 1590 with A("layer1"): 1591 with A("layer2"): 1592 torch.empty([]) 1593 1594 def test_make_subclass_with_modes(self): 1595 class ModeTensor(torch.Tensor): 1596 def __new__(cls, elem, mode): 1597 r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1598 r.elem = elem 1599 r.mode = mode 1600 return r 1601 1602 @classmethod 
1603 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1604 raise NotImplementedError("Shouldn't be here") 1605 1606 class Mode(TorchDispatchMode): 1607 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1608 def unwrap(e): 1609 if isinstance(e, ModeTensor): 1610 return e.elem 1611 else: 1612 return e 1613 1614 def wrap(t): 1615 if isinstance(t, torch.Tensor): 1616 return ModeTensor(t, self) 1617 else: 1618 return t 1619 1620 return wrap(func(*tuple(unwrap(a) for a in args), **kwargs)) 1621 1622 class BasicMode(TorchDispatchMode): 1623 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1624 return func(*args, **kwargs) 1625 1626 x = torch.tensor(4.0) 1627 with Mode(): 1628 y = x + x 1629 z = y + y 1630 self.assertIsInstance(y, ModeTensor) 1631 self.assertIsInstance(z, ModeTensor) 1632 1633 with Mode(): 1634 with BasicMode(): # we can't nest two modes that call make_subclass because it only accepts vanilla tensors 1635 y = x + x 1636 z = y + y 1637 self.assertIsInstance(y, ModeTensor) 1638 self.assertIsInstance(z, ModeTensor) 1639 1640 assert self.assertRaisesRegex( 1641 RuntimeError, 1642 "subclass Mode but.* associated to a python object of type Mode", 1643 ) 1644 1645 def test_notimplemented_mode(self): 1646 sub_count = 0 1647 1648 class PoliteMode(TorchDispatchMode): 1649 def __init__(self) -> None: 1650 self.pre_count = 0 1651 self.post_count = 0 1652 1653 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1654 self.pre_count += 1 1655 if any(t is not torch.Tensor for t in types): 1656 return NotImplemented 1657 self.post_count += 1 1658 return func(*args, **kwargs) 1659 1660 class SubTensor(torch.Tensor): 1661 def __new__(cls, elem): 1662 r = torch.Tensor._make_wrapper_subclass(cls, elem.shape) 1663 r.elem = elem 1664 return r 1665 1666 @classmethod 1667 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1668 nonlocal sub_count 1669 sub_count += 1 1670 1671 def unwrap(t): 1672 if 
isinstance(t, SubTensor): 1673 return t.elem 1674 else: 1675 return t 1676 1677 return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 1678 1679 a = SubTensor(torch.randn(2)) 1680 with PoliteMode() as mode: 1681 a.abs() 1682 1683 self.assertEqual(mode.pre_count, 2) 1684 self.assertEqual(mode.post_count, 1) 1685 self.assertEqual(sub_count, 1) 1686 1687 # make sure this doesn't error 1688 with PoliteMode(): 1689 with PoliteMode(): 1690 a.abs() 1691 1692 def test_nesting_same_mode(self): 1693 # If the pushed mode is the same instance as the current mode, we allow pushing an already active mode. 1694 1695 with capture_logs(is_mode=True) as logs: 1696 with LoggingTensorMode() as reenabled: 1697 with reenabled: 1698 torch.empty([]) 1699 self.assertExpectedInline( 1700 "\n".join(logs), 1701 """\ 1702$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1703$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""", 1704 ) 1705 1706 def test_error_using_class_method_on_mode(self): 1707 class A(TorchDispatchMode): 1708 @classmethod 1709 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1710 return func(args, kwargs) 1711 1712 x = torch.tensor(5.0) 1713 with self.assertRaisesRegex( 1714 RuntimeError, "classmethod is not supported, please make it a plain method" 1715 ): 1716 with A(): 1717 x + x 1718 1719 def test_get_cur_mode(self): 1720 class A(TorchDispatchMode): 1721 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1722 pass 1723 1724 self.assertEqual(_get_current_dispatch_mode(), None) 1725 1726 with A() as mode1: 1727 self.assertEqual(_get_current_dispatch_mode(), mode1) 1728 1729 with mode1: 1730 with A() as mode2: 1731 self.assertEqual(_get_current_dispatch_mode(), mode2) 1732 1733 def test_get_mode_stack(self): 1734 class A(TorchDispatchMode): 1735 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1736 pass 1737 1738 
self.assertEqual(_get_current_dispatch_mode_stack(), []) 1739 1740 with A() as mode1: 1741 self.assertEqual(_get_current_dispatch_mode_stack(), [mode1]) 1742 1743 with mode1: 1744 with A() as mode2: 1745 self.assertEqual(_get_current_dispatch_mode_stack(), [mode1, mode2]) 1746 1747 def test_all_same_mode(self): 1748 x = LoggingTensorMode() 1749 y = LoggingTensorMode() 1750 self.assertTrue(all_same_mode([x, x, x])) 1751 self.assertFalse(all_same_mode([x, None])) 1752 self.assertFalse(all_same_mode([x, y])) 1753 1754 def test_mode_detection(self): 1755 class InfraMode(TorchDispatchMode): 1756 @classmethod 1757 def is_infra_mode(cls): 1758 return True 1759 1760 class NonInfraMode(TorchDispatchMode): 1761 pass 1762 1763 with InfraMode(): 1764 self.assertTrue(is_in_torch_dispatch_mode()) 1765 self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False)) 1766 with NonInfraMode(): 1767 self.assertTrue(is_in_torch_dispatch_mode()) 1768 self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False)) 1769 with InfraMode(): 1770 self.assertTrue(is_in_torch_dispatch_mode()) 1771 self.assertTrue( 1772 is_in_torch_dispatch_mode(include_infra_modes=False) 1773 ) 1774 1775 self.assertTrue(is_in_torch_dispatch_mode()) 1776 self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False)) 1777 self.assertTrue(is_in_torch_dispatch_mode()) 1778 self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False)) 1779 1780 self.assertFalse(is_in_torch_dispatch_mode()) 1781 self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False)) 1782 1783 def test_tolist_numpy_with_torch_dispatch_mode(self) -> None: 1784 x = LoggingTensor(torch.tensor([2.0, 3.0])) 1785 with self.assertRaisesRegex( 1786 RuntimeError, "is not supported for tensor subclasses." 1787 ): 1788 x.tolist() 1789 with self.assertRaisesRegex( 1790 RuntimeError, "is not supported for tensor subclasses." 
1791 ): 1792 x.numpy() 1793 with self.assertRaises(AssertionError): 1794 self.assertEqual(x, None) 1795 1796 def test_record_stream(self) -> None: 1797 class TestMode(TorchDispatchMode): 1798 def __init__(self, testcase): 1799 self.testcase = testcase 1800 1801 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1802 self.testcase.assertEqual(func.name(), "aten::record_stream") 1803 self.testcase.assertIsInstance(args[0], torch.Tensor) 1804 self.testcase.assertIsInstance(args[1], torch.Stream) 1805 self.testcase.assertEqual(args[1].stream_id, 1) 1806 self.testcase.assertEqual(args[1].device_index, 2) 1807 self.testcase.assertEqual(args[1].device_type, 3) 1808 1809 t = torch.tensor(5.0) 1810 s = torch.Stream(stream_id=1, device_index=2, device_type=3) 1811 with TestMode(self): 1812 t.record_stream(s) 1813 1814 def test_return_stream(self) -> None: 1815 with _scoped_library("test_return_stream", "DEF") as l_def: 1816 l_def.define("return_stream(Tensor self) -> Stream") 1817 with _scoped_library("test_return_stream", "IMPL", "CPU") as l_impl: 1818 l_impl.impl( 1819 "return_stream", 1820 lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2), 1821 ) 1822 1823 class TestMode(TorchDispatchMode): 1824 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1825 return torch.Stream(stream_id=1, device_index=2, device_type=3) 1826 1827 t = torch.tensor(5.0) 1828 s = torch.ops.test_return_stream.return_stream(t) 1829 self.assertIsInstance(s, torch.Stream) 1830 self.assertEqual(s.stream_id, 0) 1831 self.assertEqual(s.device_index, 1) 1832 self.assertEqual(s.device_type, 2) 1833 1834 with TestMode(): 1835 s = torch.ops.test_return_stream.return_stream(t) 1836 self.assertIsInstance(s, torch.Stream) 1837 self.assertEqual(s.stream_id, 1) 1838 self.assertEqual(s.device_index, 2) 1839 self.assertEqual(s.device_type, 3) 1840 1841 def test_subclass_autograd_device_check(self) -> None: 1842 class NonWrapperSubclass(torch.Tensor): 1843 elem: 
torch.Tensor 1844 1845 __slots__ = ["elem"] 1846 1847 @staticmethod 1848 def __new__(cls, elem, *args, **kwargs): 1849 # Wrong device here! 1850 r = torch.Tensor._make_subclass( 1851 cls, elem.to("meta"), elem.requires_grad 1852 ) 1853 # ...the real tensor is held as an element on the tensor. 1854 r.elem = elem 1855 return r 1856 1857 @classmethod 1858 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1859 def unwrap(e): 1860 return e.elem if isinstance(e, NonWrapperSubclass) else e 1861 1862 def wrap(e): 1863 return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e 1864 1865 rs = tree_map( 1866 wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 1867 ) 1868 logging.getLogger("NonWrapperSubclass").info( 1869 f"{func.__module__}.{func.__name__}", # noqa: G004 1870 args, 1871 kwargs, 1872 rs, 1873 ) 1874 return rs 1875 1876 x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True)) 1877 y = torch.randn(2, requires_grad=True) 1878 z = x * y 1879 self.assertIsInstance(z, NonWrapperSubclass) 1880 z.sum().backward(torch.tensor(1)) 1881 self.assertEqual(x.grad, y) 1882 self.assertEqual(y.grad, x) 1883 1884 def test_none_wrapping(self): 1885 # A Tensor subclass that returns None when doing add 1886 # See LoggingTensor above for more details on the subclass 1887 class SubclassWithNone(torch.Tensor): 1888 @staticmethod 1889 def __new__(cls, elem, *args, **kwargs): 1890 r = torch.Tensor._make_wrapper_subclass( 1891 cls, 1892 elem.size(), 1893 dtype=elem.dtype, 1894 layout=elem.layout, 1895 device=elem.device, 1896 requires_grad=elem.requires_grad, 1897 ) 1898 r.elem = elem 1899 return r 1900 1901 @classmethod 1902 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1903 def unwrap(e): 1904 return e.elem if isinstance(e, SubclassWithNone) else e 1905 1906 def wrap(e): 1907 return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e 1908 1909 rs = tree_map( 1910 wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, 
kwargs)) 1911 ) 1912 if func.overloadpacket.__name__ == "add": 1913 return None 1914 else: 1915 return rs 1916 1917 x = SubclassWithNone(torch.rand(2)) 1918 # Make sure both run without error 1919 self.assertIsInstance(x * 2, SubclassWithNone) 1920 self.assertIsNone(x + 2) 1921 1922 x.requires_grad_() 1923 out = x.acos().sum() 1924 1925 # The backward of acos does add then rsqrt so here we make sure that the 1926 # undefined Tensor generated by the user code is nicely handled. 1927 # If acos formula changes in the future, this can be replaced by any other 1928 # function that does add then something in the backward in a composite way 1929 with self.assertRaisesRegex(RuntimeError, "but got None"): 1930 out.backward() 1931 1932 def test_storage_can_be_converted_to_python_object(self): 1933 s = torch.Storage() 1934 z = LoggingTensor(torch.empty([])) 1935 z.set_(s) 1936 1937 def test_autograd_in_attr(self): 1938 # We want the wrapped Tensor to require gradients! 1939 true_t = torch.rand(2, requires_grad=True) 1940 t = LoggingTensorReentrant(true_t) 1941 1942 out = t + 2 1943 1944 self.assertFalse(out.requires_grad) 1945 self.assertIsNone(out.grad_fn) 1946 1947 self.assertTrue(out.elem.requires_grad) 1948 self.assertIsNotNone(out.elem.grad_fn) 1949 1950 with self.assertRaisesRegex(RuntimeError, "does not require grad"): 1951 out.sum().backward() 1952 1953 out.elem.sum().backward() 1954 1955 self.assertIsNone(t.grad) 1956 self.assertIsNotNone(t.elem.grad) 1957 1958 def test_dispatch_super_call(self): 1959 called = [] 1960 1961 class SubTensor(torch.Tensor): 1962 @staticmethod 1963 def __new__(cls, elem): 1964 return torch.Tensor._make_subclass(cls, elem) 1965 1966 @classmethod 1967 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1968 called.append(func) 1969 return super().__torch_dispatch__(func, types, args, kwargs) 1970 1971 x = torch.randn(2) 1972 y = torch.randn(2) 1973 self.assertEqual(SubTensor(x) + SubTensor(y), x + y) 1974 
self.assertEqual(called, [torch.ops.aten.add.Tensor]) 1975 1976 def test_dispatch_super_call_list_arg(self): 1977 called = [] 1978 1979 class SubTensorWithListArg(torch.Tensor): 1980 @staticmethod 1981 def __new__(cls, elem): 1982 return torch.Tensor._make_subclass(cls, elem) 1983 1984 @classmethod 1985 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1986 called.append(func) 1987 return super().__torch_dispatch__(func, types, list(args), kwargs) 1988 1989 x = torch.randn(2) 1990 self.assertEqual(SubTensorWithListArg(x).neg(), x.neg()) 1991 self.assertEqual(called, [torch.ops.aten.neg.default]) 1992 1993 def test_dispatch_super_dont_autograd(self): 1994 called = [] 1995 1996 class SubTensor(torch.Tensor): 1997 @staticmethod 1998 def __new__(cls, elem): 1999 return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 2000 2001 @classmethod 2002 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 2003 called.append(func) 2004 # This argument still requires grad because it was passed 2005 # through directly... 
2006 self.assertTrue(args[0].requires_grad) 2007 r = super().__torch_dispatch__(func, types, args, kwargs) 2008 # But the output better not require grad, because that means 2009 # you did autograd again in torch dispatch (oops) 2010 self.assertFalse(r.requires_grad) 2011 return r 2012 2013 x = SubTensor(torch.randn(2, requires_grad=True)) 2014 x.neg() 2015 self.assertEqual(called, [torch.ops.aten.neg.default]) 2016 2017 def test_set_data(self): 2018 called = 0 2019 2020 class SubTensor(torch.Tensor): 2021 @classmethod 2022 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 2023 nonlocal called 2024 called += 1 2025 return super().__torch_dispatch__(func, types, args, kwargs) 2026 2027 x = SubTensor(torch.empty(2)) 2028 x.data 2029 self.assertEqual(called, 1) 2030 x.data = torch.empty(2) 2031 self.assertEqual(called, 1) 2032 x.data 2033 self.assertEqual(called, 2) 2034 self.assertIs(type(x), SubTensor) 2035 x.set_(torch.empty(2)) 2036 self.assertEqual(called, 3) 2037 x.data 2038 self.assertEqual(called, 4) 2039 self.assertIs(type(x), SubTensor) 2040 2041 def test_construct_int_tensor(self): 2042 class SubTensor(torch.Tensor): 2043 pass 2044 2045 # should not fail 2046 SubTensor(torch.zeros(2, dtype=torch.int)) 2047 2048 def test_multiple_ops_subclass(self): 2049 # This is a Direct Subclass, don't do that! 
2050 class MySubclass(torch.Tensor): 2051 @staticmethod 2052 def __new__(cls, elem): 2053 r = torch.Tensor._make_subclass(cls, elem) 2054 return r 2055 2056 @classmethod 2057 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 2058 with no_dispatch(): 2059 return func(*args, **kwargs) 2060 2061 x = MySubclass(torch.rand(2, 2, dtype=torch.complex64)) 2062 y = x.conj() 2063 # Details of the bug that this tests for: 2064 # Here, y dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU} 2065 # There are a few calls to the dispatcher that are going to happen here: 2066 # - call_exp: User calling exp on y 2067 # - PythonTLSSnapshot: records the TLS on entry and redispatch 2068 # - AutogradCPU: no input requires grad, so does nothing and redispatch 2069 # - Conjugate: no special implementation for exp: use the fallback that 2070 # first clone the Tensor (to materialize the conj) then redispatch 2071 # - call_clone: conjugate fallback calling clone on y 2072 # - PythonTLSSnapshot: records the TLS on entry and redispatch 2073 # - (AutogradCPU: skipped as autograd added itself to the exclude set above) 2074 # - Conjugate: special implementation for clone: just skip this key 2075 # - Python: Reset the TLS based on the snapshot above and call the user implementation (this 2076 # actually calls into the dispatcher again but since we disable both our keys 2077 # before, not detailed here) 2078 # - exit Python: restore the TLS and exit 2079 # - exit Conjugate: nothing was inplace so just exit 2080 # - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty 2081 # - Python: Reset the TLS again based on the snapshot. <- this used to fail 2082 # - More steps.... 
2083 y.exp() 2084 2085 @staticmethod 2086 def subclass_helper(cls, data, use_wrapper_subclass, **kwargs): 2087 if use_wrapper_subclass: 2088 kwargs["device"] = data.device 2089 kwargs["dtype"] = data.dtype 2090 kwargs["layout"] = data.layout 2091 kwargs["requires_grad"] = True 2092 return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined] 2093 else: 2094 return torch.Tensor._make_subclass(cls, data, True, **kwargs) 2095 2096 def test_is_contiguous_slow_path(self): 2097 data = torch.randn(3, 3) 2098 contiguous_data = data.clone() 2099 not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2)) 2100 2101 for use_wrapper_subclass in [True, False]: 2102 2103 class ExampleTensor1(torch.Tensor): 2104 @staticmethod 2105 def __new__(cls, data, wrapper): 2106 return TestPythonDispatch.subclass_helper( 2107 cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2108 ) 2109 2110 @classmethod 2111 def __torch_dispatch__(cls, func, types, args, kwargs): 2112 return NotImplemented 2113 2114 class ExampleTensor2(torch.Tensor): 2115 @staticmethod 2116 def __new__(cls, data, wrapper): 2117 return TestPythonDispatch.subclass_helper( 2118 cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2119 ) 2120 2121 @classmethod 2122 def __torch_dispatch__(cls, func, types, args, kwargs): 2123 if func.overloadpacket == torch.ops.aten.is_contiguous: 2124 return contiguous_data.is_contiguous() 2125 return NotImplemented 2126 2127 class ExampleTensor3(torch.Tensor): 2128 @staticmethod 2129 def __new__(cls, data, wrapper): 2130 return TestPythonDispatch.subclass_helper( 2131 cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2132 ) 2133 2134 @classmethod 2135 def __torch_dispatch__(cls, func, types, args, kwargs): 2136 if func.overloadpacket == torch.ops.aten.is_contiguous: 2137 return not_contiguous_data.is_contiguous() 2138 return NotImplemented 2139 2140 err_msg = "Multiple dispatch failed for 
'torch.ops.aten.is_contiguous'" 2141 e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass) 2142 with self.assertRaisesRegex(TypeError, err_msg): 2143 e.is_contiguous() 2144 with self.assertRaisesRegex(TypeError, err_msg): 2145 e.contiguous() 2146 2147 e = ExampleTensor2(torch.randn(3, 3), use_wrapper_subclass) 2148 self.assertEqual(e.is_contiguous(), True) 2149 e.contiguous() # this will just return the original TensorImpl since is_contiguous = True 2150 2151 err_msg = "Multiple dispatch failed for" 2152 e = ExampleTensor3(torch.randn(3, 3), use_wrapper_subclass) 2153 self.assertEqual(e.is_contiguous(), False) 2154 with self.assertRaisesRegex(TypeError, err_msg): 2155 e.contiguous() 2156 2157 def test_fancy_strides(self): 2158 calls = [] 2159 2160 class ExampleTensor(torch.Tensor): 2161 @staticmethod 2162 def __new__(cls, data): 2163 return TestPythonDispatch.subclass_helper( 2164 cls, data, False, dispatch_sizes_strides_policy="strides" 2165 ) 2166 2167 @classmethod 2168 def __torch_dispatch__(cls, func, types, args, kwargs): 2169 if func in [ 2170 torch.ops.aten.is_contiguous.default, 2171 torch.ops.aten.is_contiguous.memory_format, 2172 torch.ops.aten.is_strides_like_format.default, 2173 torch.ops.aten.is_non_overlapping_and_dense.default, 2174 torch.ops.aten.stride.default, 2175 ]: 2176 calls.append((func, list(args)[1:])) 2177 return None 2178 with no_dispatch(): 2179 return func(*args, **kwargs) 2180 2181 e = ExampleTensor(torch.randn(2, 2)) 2182 self.assertFalse(e.is_contiguous(memory_format=torch.channels_last)) 2183 self.assertEqual( 2184 calls, [(torch.ops.aten.is_contiguous.memory_format, [torch.channels_last])] 2185 ) 2186 calls.clear() 2187 self.assertFalse( 2188 torch.ops.aten.is_strides_like_format.default(e, torch.channels_last) 2189 ) 2190 self.assertEqual( 2191 calls, 2192 [(torch.ops.aten.is_strides_like_format.default, [torch.channels_last])], 2193 ) 2194 calls.clear() 2195 
self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(e)) 2196 self.assertEqual( 2197 calls, [(torch.ops.aten.is_non_overlapping_and_dense.default, [])] 2198 ) 2199 2200 def test_device_slowpath(self): 2201 for use_wrapper_subclass in [True]: 2202 2203 class ExampleTensor1(torch.Tensor): 2204 @staticmethod 2205 def __new__(cls, data, wrapper): 2206 return TestPythonDispatch.subclass_helper( 2207 cls, data, wrapper, dispatch_device=True 2208 ) 2209 2210 @classmethod 2211 def __torch_dispatch__(cls, func, types, args, kwargs): 2212 return NotImplemented 2213 2214 class ExampleTensor2(torch.Tensor): 2215 @staticmethod 2216 def __new__(cls, data, wrapper): 2217 return TestPythonDispatch.subclass_helper( 2218 cls, data, wrapper, dispatch_device=True 2219 ) 2220 2221 @classmethod 2222 def __torch_dispatch__(cls, func, types, args, kwargs): 2223 if func.overloadpacket == torch.ops.prim.device: 2224 return torch.device("meta") 2225 return NotImplemented 2226 2227 class ExampleTensor3(torch.Tensor): 2228 @staticmethod 2229 def __new__(cls, data, wrapper): 2230 return TestPythonDispatch.subclass_helper( 2231 cls, data, wrapper, dispatch_device=True 2232 ) 2233 2234 @classmethod 2235 def __torch_dispatch__(cls, func, types, args, kwargs): 2236 if func.overloadpacket == torch.ops.prim.device: 2237 return torch.device("meta") 2238 return NotImplemented 2239 2240 err_msg = "Multiple dispatch failed for 'torch.ops.prim.device'" 2241 with self.assertRaisesRegex(TypeError, err_msg): 2242 e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass) 2243 e.device() 2244 2245 ten = torch.rand([1]) 2246 e = ExampleTensor2(torch.randn(3, 3, device="cpu"), use_wrapper_subclass) 2247 self.assertEqual(e.device.type, "meta") 2248 self.assertEqual(ten.type_as(e).device.type, "meta") 2249 2250 e = ExampleTensor3(torch.randn(3, 3, device="cpu"), use_wrapper_subclass) 2251 self.assertEqual(e.device.type, "meta") 2252 self.assertEqual(ten.type_as(e).device.type, "meta") 2253 2254 
def test_dim_slowpath(self): 2255 data = torch.randn(3, 3) 2256 2257 for use_wrapper_subclass in [True, False]: 2258 2259 class DimNotImplementedTensor(torch.Tensor): 2260 @staticmethod 2261 def __new__(cls, data, wrapper): 2262 return TestPythonDispatch.subclass_helper( 2263 cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2264 ) 2265 2266 @classmethod 2267 def __torch_dispatch__(cls, func, types, args, kwargs): 2268 return NotImplemented 2269 2270 class DimImplementedTensor(torch.Tensor): 2271 @staticmethod 2272 def __new__(cls, data, wrapper): 2273 return TestPythonDispatch.subclass_helper( 2274 cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2275 ) 2276 2277 @classmethod 2278 def __torch_dispatch__(cls, func, types, args, kwargs): 2279 if func.overloadpacket == torch.ops.aten.dim: 2280 return data.dim() 2281 return NotImplemented 2282 2283 err_msg = "Multiple dispatch failed for 'torch.ops.aten.dim'" 2284 e = DimNotImplementedTensor(torch.randn(3, 3), use_wrapper_subclass) 2285 with self.assertRaisesRegex(TypeError, err_msg): 2286 e.dim() 2287 2288 t = DimImplementedTensor(torch.randn(3, 3), use_wrapper_subclass) 2289 self.assertEqual(t.dim(), 2) 2290 2291 def test_maybe_tuple_bug(self): 2292 class T(torch.Tensor): 2293 @classmethod 2294 def __torch_function__(cls, *args, **kwargs): 2295 pass 2296 2297 a = torch.rand(3) 2298 2299 a[[T(), T()]] 2300 2301 def test_standard_is_not_subclass(self): 2302 # https://github.com/pytorch/pytorch/issues/79079 2303 self.assertFalse(torch._C._dispatch_isTensorSubclassLike(torch.empty(0))) 2304 2305 def test_sym_sizes_strides_slow_path(self): 2306 class TestTensor(torch.Tensor): 2307 @staticmethod 2308 def __new__(cls, *args, **kwargs): 2309 r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 2310 cls, (0,), dispatch_sizes_strides_policy="sizes" 2311 ) 2312 return r 2313 2314 @classmethod 2315 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 2316 if func in ( 2317 
torch.ops.aten.sym_size.default, 2318 torch.ops.aten.sym_stride.default, 2319 ): 2320 from torch._dynamo.source import ConstantSource 2321 from torch.fx.experimental.symbolic_shapes import ( 2322 DimDynamic, 2323 ShapeEnv, 2324 ) 2325 2326 shape_env = ShapeEnv() 2327 si = shape_env.create_symintnode( 2328 shape_env.create_symbol( 2329 123, 2330 source=ConstantSource("abc"), 2331 dynamic_dim=DimDynamic.DUCK, 2332 constraint_dim=None, 2333 ), 2334 hint=123, 2335 ) 2336 return (si,) 2337 2338 t = TestTensor() 2339 si = t.size()[0] 2340 self.assertIsInstance(si, torch.SymInt) 2341 si = t.stride()[0] 2342 self.assertIsInstance(si, torch.SymInt) 2343 2344 def test_strides_slow_path(self): 2345 for use_wrapper_subclass in [True, False]: 2346 2347 class StridesNotImplemented(torch.Tensor): 2348 @staticmethod 2349 def __new__(cls, data, wrapper): 2350 return TestPythonDispatch.subclass_helper( 2351 cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2352 ) 2353 2354 @classmethod 2355 def __torch_dispatch__(cls, func, types, args, kwargs): 2356 return NotImplemented 2357 2358 class StridesCustomReturn(torch.Tensor): 2359 @staticmethod 2360 def __new__(cls, data, wrapper): 2361 return TestPythonDispatch.subclass_helper( 2362 cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2363 ) 2364 2365 @classmethod 2366 def __torch_dispatch__(cls, func, types, args, kwargs): 2367 if func == torch.ops.aten.sym_stride.default: 2368 return (4, 2) 2369 return NotImplemented 2370 2371 class StridesDefaultReturn(torch.Tensor): 2372 @staticmethod 2373 def __new__(cls, data, wrapper): 2374 return TestPythonDispatch.subclass_helper( 2375 cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2376 ) 2377 2378 @classmethod 2379 def __torch_dispatch__(cls, func, types, args, kwargs): 2380 if func == torch.ops.aten.sym_stride.default: 2381 return None 2382 return NotImplemented 2383 2384 err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_stride'" 2385 e = 
StridesNotImplemented(torch.randn(3, 3), use_wrapper_subclass) 2386 with self.assertRaisesRegex(TypeError, err_msg): 2387 e.stride() 2388 2389 e = StridesCustomReturn(torch.randn(3, 3), use_wrapper_subclass) 2390 self.assertEqual(e.stride(), (4, 2)) 2391 2392 e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass) 2393 self.assertEqual(e.stride(), (2, 1)) 2394 2395 def test_sizes_slow_path(self): 2396 for use_wrapper_subclass in [True, False]: 2397 data = torch.randn(6, 2) 2398 2399 class SizesNotImplemented(torch.Tensor): 2400 @staticmethod 2401 def __new__(cls, data, wrapper): 2402 return TestPythonDispatch.subclass_helper( 2403 cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2404 ) 2405 2406 @classmethod 2407 def __torch_dispatch__(cls, func, types, args, kwargs): 2408 if func.overloadpacket == torch.ops.aten.dim: 2409 return data.dim() 2410 return NotImplemented 2411 2412 class SizesCustomReturn(torch.Tensor): 2413 @staticmethod 2414 def __new__(cls, data, wrapper): 2415 return TestPythonDispatch.subclass_helper( 2416 cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2417 ) 2418 2419 @classmethod 2420 def __torch_dispatch__(cls, func, types, args, kwargs): 2421 if func.overloadpacket == torch.ops.aten.dim: 2422 return data.dim() 2423 if func.overloadpacket == torch.ops.aten.sym_size: 2424 return (5, 3) 2425 return NotImplemented 2426 2427 class SizesDefaultReturn(torch.Tensor): 2428 @staticmethod 2429 def __new__(cls, data, wrapper): 2430 return TestPythonDispatch.subclass_helper( 2431 cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2432 ) 2433 2434 @classmethod 2435 def __torch_dispatch__(cls, func, types, args, kwargs): 2436 if func.overloadpacket == torch.ops.aten.dim: 2437 return data.dim() 2438 if func.overloadpacket == torch.ops.aten.sym_size: 2439 return None 2440 return NotImplemented 2441 2442 err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_size'" 2443 e = SizesNotImplemented(torch.randn(3, 3), 
use_wrapper_subclass) 2444 with self.assertRaisesRegex(TypeError, err_msg): 2445 e.size() 2446 2447 e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass) 2448 self.assertEqual(e.size(), (5, 3)) 2449 2450 e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass) 2451 self.assertEqual(e.size(), (4, 2)) 2452 2453 def test_custom_size_policy_dynamic_shapes(self): 2454 data = torch.randn(6, 2) 2455 2456 class CustomSizeDynamicShapesTensor(torch.Tensor): 2457 @staticmethod 2458 def __new__(cls, inner): 2459 return torch.Tensor._make_wrapper_subclass( 2460 # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. 2461 # Calling the overload that has kwargs causes us to go down the first overload path, 2462 # which will **always** specialize sizes. 2463 # We should probably eventually fix this so that the first overload can just handle dynamic shapes. 2464 cls, 2465 inner.size(), 2466 inner.stride(), 2467 None, 2468 None, 2469 inner.dtype, 2470 inner.layout, 2471 inner.device, 2472 False, 2473 inner.requires_grad, 2474 "sizes", 2475 ) 2476 2477 def __init__(self, inner): 2478 self.inner = inner 2479 2480 @classmethod 2481 def __torch_dispatch__(cls, func, types, args, kwargs): 2482 if func == torch.ops.aten.sym_size.default: 2483 return args[0].inner.shape 2484 if func == torch.ops.aten.sym_stride.default: 2485 return args[0].inner.shape 2486 return NotImplemented 2487 2488 x = torch.ones(2, 2) 2489 2490 def trace_fn(x): 2491 x_wrapper = CustomSizeDynamicShapesTensor(x) 2492 return x_wrapper.size(), x_wrapper.stride() 2493 2494 fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x) 2495 self.assertExpectedInline( 2496 fx_g.code.strip(), 2497 """\ 2498def forward(self, x_1): 2499 sym_size_int = torch.ops.aten.sym_size.int(x_1, 0) 2500 sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None 2501 return ((sym_size_int, sym_size_int_1), (sym_size_int, sym_size_int_1))""", 2502 ) 2503 2504 def 
test_data_ptr_respects_numel_slow_path(self): 2505 data = torch.randn(6, 2) 2506 2507 class NumelDefaultReturn(torch.Tensor): 2508 @staticmethod 2509 def __new__(cls, data, wrapper): 2510 return TestPythonDispatch.subclass_helper( 2511 cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2512 ) 2513 2514 @classmethod 2515 def __torch_dispatch__(cls, func, types, args, kwargs): 2516 if func.overloadpacket == torch.ops.aten.dim: 2517 return data.dim() 2518 if func.overloadpacket == torch.ops.aten.numel: 2519 numel_called[0] = True 2520 return None 2521 return NotImplemented 2522 2523 for use_wrapper_subclass in (False, True): 2524 numel_called = [False] 2525 e = NumelDefaultReturn(torch.randn(2, 2), use_wrapper_subclass) 2526 e.data_ptr() 2527 self.assertTrue(numel_called[0]) 2528 2529 def test_layout_slow_path(self): 2530 for use_wrapper_subclass in [True, False]: 2531 data = torch.randn(6, 2) 2532 2533 class LayoutNotImplemented(torch.Tensor): 2534 @staticmethod 2535 def __new__(cls, data, wrapper): 2536 return TestPythonDispatch.subclass_helper( 2537 cls, data, wrapper, dispatch_layout=True 2538 ) 2539 2540 @classmethod 2541 def __torch_dispatch__(cls, func, types, args, kwargs): 2542 return NotImplemented 2543 2544 class LayoutCustomReturn(torch.Tensor): 2545 @staticmethod 2546 def __new__(cls, data, wrapper): 2547 return TestPythonDispatch.subclass_helper( 2548 cls, data, wrapper, dispatch_layout=True 2549 ) 2550 2551 @classmethod 2552 def __torch_dispatch__(cls, func, types, args, kwargs): 2553 if func.overloadpacket == torch.ops.prim.layout: 2554 return torch.sparse_csr 2555 return NotImplemented 2556 2557 class LayoutDefaultReturn(torch.Tensor): 2558 @staticmethod 2559 def __new__(cls, data, wrapper): 2560 return TestPythonDispatch.subclass_helper( 2561 cls, data, wrapper, dispatch_layout=True 2562 ) 2563 2564 @classmethod 2565 def __torch_dispatch__(cls, func, types, args, kwargs): 2566 if func.overloadpacket == torch.ops.prim.layout: 2567 return 
data.layout 2568 return NotImplemented 2569 2570 err_msg = "Multiple dispatch failed for 'torch.ops.prim.layout'" 2571 e = LayoutNotImplemented(torch.randn(3, 3), use_wrapper_subclass) 2572 with self.assertRaisesRegex(TypeError, err_msg): 2573 e.layout 2574 2575 e = LayoutCustomReturn(torch.randn(3, 3), use_wrapper_subclass) 2576 self.assertEqual(e.layout, torch.sparse_csr) 2577 2578 e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass) 2579 self.assertEqual(e.layout, torch.strided) 2580 2581 2582class TestPythonDispatcher(TestCase): 2583 def test_basic(self): 2584 x = torch.randn(2, requires_grad=True) 2585 r = torch._C._EnablePythonDispatcher() 2586 torch.add(x, x) 2587 2588 def test_lstsq(self): 2589 a = torch.randn(4, 3) 2590 b = torch.rand(4, 3) 2591 expected_shape = torch.linalg.lstsq(a, b).solution.shape 2592 r = torch._C._EnablePythonDispatcher() 2593 python_disp_shape = torch.linalg.lstsq(a, b).solution.shape 2594 self.assertEqual(expected_shape, python_disp_shape) 2595 2596 2597class TestWrapperSubclassAliasing(TestCase): 2598 def _test_wrapper_subclass_aliasing(self, op, args, kwargs): 2599 def to_subclass(t: torch.Tensor): 2600 return TwoTensor(t, t.clone()) 2601 2602 result_ref = op(*args, **kwargs) 2603 2604 args_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, args) 2605 kwargs_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, kwargs) 2606 2607 result_test = op(*args_subclass, **kwargs_subclass) 2608 2609 args_ref_flat = pytree.arg_tree_leaves(*args, **kwargs) 2610 args_ref_flat_tensors = [ 2611 x for x in args_ref_flat if isinstance(x, torch.Tensor) 2612 ] 2613 2614 args_test_flat = pytree.tree_leaves((args_subclass, kwargs_subclass)) 2615 args_test_flat_tensors = [ 2616 x for x in args_test_flat if isinstance(x, torch.Tensor) 2617 ] 2618 2619 result_ref_flat = pytree.tree_leaves(result_ref) 2620 result_ref_flat_tensors = [ 2621 x for x in result_ref_flat if isinstance(x, torch.Tensor) 2622 ] 2623 2624 result_test_flat 
= pytree.tree_leaves(result_test) 2625 result_test_flat_tensors = [ 2626 x for x in result_test_flat if isinstance(x, torch.Tensor) 2627 ] 2628 2629 for o_ref, o_test in zip(result_ref_flat_tensors, result_test_flat_tensors): 2630 for a_ref, a_test in zip(args_ref_flat_tensors, args_test_flat_tensors): 2631 out_is_inpt = o_ref is a_ref 2632 if out_is_inpt: 2633 self.assertTrue(o_test is a_test) 2634 2635 out_aliases_inpt = StorageWeakRef( 2636 o_ref.untyped_storage() 2637 ) == StorageWeakRef(a_ref.untyped_storage()) 2638 if out_aliases_inpt: 2639 self.assertTrue( 2640 StorageWeakRef(o_test.untyped_storage()) 2641 == StorageWeakRef(a_test.untyped_storage()) 2642 ) 2643 else: 2644 self.assertFalse( 2645 StorageWeakRef(o_test.untyped_storage()) 2646 == StorageWeakRef(a_test.untyped_storage()) 2647 ) 2648 2649 # This tests the correctness of `torch.utils._python_dispatch.return_and_correct_aliasing`, 2650 # a util for wrapper subclasses to promise correct aliasing behavior. 2651 # It's probably overkill to test every OpInfo, 2652 # so I picked a sampling of ops with representative schemas. 
2653 @ops( 2654 [ 2655 op 2656 for op in op_db 2657 if op.name 2658 in [ 2659 "mul", # out-of-place 2660 "cat", # out-of-place (TensorList input) 2661 "index", # out-of-place (Optional TensorList input) 2662 "mul_", # inplace 2663 "view", # view 2664 "t_", # inplace-view 2665 "split", # view (multi-return) 2666 "native_batch_norm", # mutable op (returns outputs and mutates some inputs) 2667 ] 2668 ], 2669 allowed_dtypes=(torch.float,), 2670 ) 2671 def test_wrapper_subclass_aliasing(self, device, dtype, op): 2672 samples = op.sample_inputs(device, dtype) 2673 sample = first_sample(self, samples) 2674 args = (sample.input, *sample.args) 2675 kwargs = sample.kwargs 2676 self._test_wrapper_subclass_aliasing(op, args, kwargs) 2677 2678 @ops(custom_op_db, allowed_dtypes=(torch.float,)) 2679 def test_wrapper_subclass_aliasing_custom(self, device, dtype, op): 2680 samples = op.sample_inputs(device, dtype) 2681 sample = first_sample(self, samples) 2682 args = (sample.input, *sample.args) 2683 kwargs = sample.kwargs 2684 self._test_wrapper_subclass_aliasing(op, args, kwargs) 2685 2686 def test_wrapper_subclass_aliasing_conv2d(self, device): 2687 args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4)) 2688 kwargs = {} 2689 # conv2d has a default arg 'int[2] strides=0', 2690 # which torchscript expands into 'int[2] strides=[0, 0]' 2691 # Make sure that _return_and_correct_aliasing can handle this case 2692 # (I'm using inference_mode to make sure conv2d doesn't decompose and goes to torch_dispatch) 2693 with torch.inference_mode(): 2694 self._test_wrapper_subclass_aliasing( 2695 torch.ops.aten.conv2d.default, args, kwargs 2696 ) 2697 2698 def test_wrapper_subclass_aliasing_out_op(self, device): 2699 # Make sure that _return_and_correct_aliasing can handle kwargs w mutable tensors 2700 args = (torch.ones(4), torch.ones(4)) 2701 kwargs = {"out": torch.empty(4)} 2702 self._test_wrapper_subclass_aliasing(torch.ops.aten.add.out, args, kwargs) 2703 2704 
# Instantiate device-specific variants of TestWrapperSubclassAliasing (its
# tests take a `device` argument) into this module's globals.
instantiate_device_type_tests(TestWrapperSubclassAliasing, scope=globals())


# Only run the suite when this file is executed directly, not when imported.
if __name__ == "__main__":
    run_tests()