# Owner(s): ["module: custom-operators"]

import collections
import itertools
import os
import re
import subprocess
import sys
import typing
import unittest
from typing import *  # noqa: F403

import numpy as np

import torch._custom_ops as custom_ops
import torch.testing._internal.optests as optests
import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from functorch import make_fx
from torch import Tensor
from torch._custom_op.impl import CustomOp, infer_schema
from torch._library.infer_schema import tuple_to_list
from torch._utils_internal import get_file_path_2
from torch.testing._internal import custom_op_db
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    subtest,
    TestCase,
)
from torch.testing._internal.custom_op_db import numpy_nonzero


# Shadowed by `torch.testing._internal.common_utils.custom_op`
from torch._custom_op.impl import custom_op  # usort: skip


def requires_compile(fun):
    fun = unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")(fun)
    return fun


class CustomOpTestCaseBase(TestCase):
    test_ns = "_test_custom_op"

    def setUp(self):
        super().setUp()
        self.libraries = []

    def tearDown(self):
        super().tearDown()
        import torch._custom_op

        keys = list(torch._custom_op.impl.global_registry.keys())
        for key in keys:
            if not key.startswith(f"{self.test_ns}::"):
                continue
            torch._custom_op.impl.global_registry[key]._destroy()
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        for lib in self.libraries:
            lib._destroy()
        del self.libraries

    def ns(self):
        return getattr(torch.ops, self.test_ns)

    def lib(self):
        result = torch.library.Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.libraries.append(result)
        return result

    def get_op(self, qualname):
        return torch._custom_op.impl.get_op(qualname)


@requires_compile
class TestCustomOpTesting(CustomOpTestCaseBase):
    @parametrize("check_gradients", (False, "auto"))
    @parametrize("dynamic", (True, False))
    def test_aot_autograd_check_degenerate_cases(
        self, device, dynamic, check_gradients
    ):
        def simple(x):
            return x.clone()

        # Should not raise
        x = torch.randn(3, device=device)
        optests.aot_autograd_check(
            simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

        def outputs_dont_require_grad(x):
            return x.detach()

        # Should not raise
        y = torch.randn(3, device=device, requires_grad=True)
        optests.aot_autograd_check(
            outputs_dont_require_grad,
            (y,),
            {},
            dynamic=dynamic,
            check_gradients=check_gradients,
        )

        def no_outputs(x):
            return x.detach()

        # Should not raise
        x = torch.randn(3, device=device, requires_grad=True)
        y = torch.randn(3, device=device, requires_grad=False)
        optests.aot_autograd_check(
            no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )
        optests.aot_autograd_check(
            no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

    def test_incorrect_schema_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

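        # The schema above declares foo as functional, but the CPU/CUDA
        # kernel below mutates its input in-place; opcheck is expected to
        # flag the mismatch.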
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                guard = torch._C._AutoDispatchBelowAutograd()
                try:
                    return op(x)
                finally:
                    del guard

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            x.sin_()
            return x.clone()

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            optests.OpCheckError, "Argument x is not defined as mutable but was mutated"
        ):
            torch.library.opcheck(op, (x,), {})

    def test_incorrect_schema_view(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                with torch._C._AutoDispatchBelowAutograd():
                    with torch._C._ExcludeDispatchKeyGuard(
                        torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                    ):
                        return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.view_as(x)

        def foo_meta(x):
            return x.view_as(x)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "Argument x is not defined to alias output but was aliasing",
        ):
            torch.library.opcheck(op, (x,), {})

    def test_missing_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return 2 * gx

        def foo_impl(x):
            return torch.tensor(x.cpu().numpy() ** 2, device=x.device)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "_test_custom_op.foo.default",
        ):
            torch.library.opcheck(op, (x,), {})

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_incorrect_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                guard = torch._C._AutoDispatchBelowAutograd()
                guard2 = torch._C.ExcludeDispatchKeyGuard(
                    torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                )
                try:
                    return op(x)
                finally:
                    del guard
                    del guard2

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x**2

        def foo_meta(x):
            return x.unsqueeze(1) ** 2

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"):
            torch.library.opcheck(op, (x,), {})
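        # Roughly: opcheck also runs the op with FakeTensors and cross-checks
        # the fake output's metadata against the eager result, which is how
        # the wrong-shaped meta kernel above gets caught.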

    def test_missing_functionalization(self, device):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.sin_()

        def foo_meta(x):
            return x

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0])
        y = x.clone()
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "We only support functionalizing operators whose outputs do not have alias annotations",
        ):
            torch.library.opcheck(op, (y,), {})

    def test_autograd_registered_at_backend(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gx):
                return gx * 0.5

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")
        lib.impl("foo", lambda x: x.clone(), "Meta")

        x = torch.randn([], requires_grad=True)

        with self.assertRaisesRegex(
            torch.testing._internal.optests.OpCheckError,
            "does not have an autograd kernel",
        ):
            torch.library.opcheck(op, (x,), {})

        # I'm not sure why this is necessary
        del lib

    def test_global_state_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            invoked = 0

            @staticmethod
            def forward(ctx, x):
                Foo.invoked += 1
                return x.clone() * Foo.invoked

            @staticmethod
            def backward(ctx, gx):
                return gx

        lib.impl("foo", Foo.apply, "CompositeImplicitAutograd")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
        ):
            torch.library.opcheck(op, (x,), {})

    @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
    def test_opcheck_opinfo(self, device, dtype, op):
        for sample_input in op.sample_inputs(
            device, dtype, requires_grad=op.supports_autograd
        ):
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            torch.library.opcheck(op.op, args, kwargs)

    def test_opcheck_fails_basic(self, device):
        @custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor: ...
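        # The "..." body is deliberate: with this decorator the function is
        # only a schema declaration; kernels are attached separately via
        # foo.impl below. (A rough sketch of the newer torch.library API,
        # which bundles declaration and implementation -- not exercised by
        # this test:
        #
        #   @torch.library.custom_op("mylib::foo", mutates_args=())
        #   def foo(x: torch.Tensor) -> torch.Tensor:
        #       return x.sum()
        # )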

        @foo.impl(["cpu", "cuda"])
        def foo_impl(x):
            return x.sum()

        x = torch.randn(3, device=device, requires_grad=True)
        # Triggers the CustomOp autograd NYI error
        with self.assertRaisesRegex(
            optests.OpCheckError, "Autograd has not been implemented for operator"
        ):
            torch.library.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {})

    def test_autograd_registration_check_autograd_kernel(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.randn(3, requires_grad=True, device=device)
        # Should not raise
        optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_compositeimplicitautograd(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        def foo_impl(x):
            return x.sin().cos()

        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")

        x = torch.randn(3, requires_grad=True, device=device)
        # Should not raise
        optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_incorrect_composite(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        def foo_impl(x):
            return x.sin().cos()

        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

        x = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(op, (x,), {})

    def test_autograd_registration_check_incorrect(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return torch.sin(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")

        x = torch.randn(3, requires_grad=True, device=device)
        with self.assertRaisesRegex(AssertionError, "incorrectly registered"):
            optests.autograd_registration_check(op, (x,), {})

    def test_assert_raises_regex(self, device):
        from torch.testing._internal.optests.aot_autograd import assert_raises_regex

        with assert_raises_regex(RuntimeError, "c"):
            raise RuntimeError("abcd")
        with assert_raises_regex(RuntimeError, "c.*"):
            raise RuntimeError("abcd")
        with self.assertRaisesRegex(AssertionError, "instead got"):
            with assert_raises_regex(RuntimeError, "c.*"):
                raise ValueError("abcd")
        with self.assertRaisesRegex(AssertionError, "Expected exception"):
            with assert_raises_regex(RuntimeError, "c.*"):
                pass
        with self.assertRaisesRegex(AssertionError, "to match regex"):
            with assert_raises_regex(RuntimeError, "f"):
                raise RuntimeError("abcd")


class TestCustomOp(CustomOpTestCaseBase):
    test_ns = "_test_custom_op"

    @requires_compile
    def test_functionalize_error(self):
        with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib:
lib.define("foo(Tensor(a!) x) -> Tensor(a!)") 473 474 def foo(x): 475 return x.sin_() 476 477 lib.impl("foo", foo, "CompositeExplicitAutograd") 478 foo_op = self.get_op(f"{self.test_ns}::foo") 479 480 lib.define("bar(Tensor(a) x) -> Tensor(a)") 481 482 def bar(x): 483 return x.view(-1) 484 485 lib.impl("bar", bar, "CompositeExplicitAutograd") 486 bar_op = self.get_op(f"{self.test_ns}::bar") 487 488 msg = r".*We only support functionalizing operators whose outputs do not have alias annotations" 489 490 x = torch.randn(3) 491 492 @torch.compile(backend="aot_eager", fullgraph=True) 493 def f(x): 494 return foo_op(x) 495 496 @torch.compile(backend="aot_eager", fullgraph=True) 497 def g(x): 498 return bar_op(x) 499 500 with self.assertRaisesRegex(RuntimeError, msg): 501 f(x) 502 with self.assertRaisesRegex(RuntimeError, msg): 503 g(x) 504 505 def test_invalid_schemas(self): 506 # function schmea validation goes through torchgen, so this is just a 507 # basic test. 508 with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"): 509 custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(") 510 511 def test_invalid_qualname(self): 512 with self.assertRaisesRegex(ValueError, "overload"): 513 custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo.Tensor", "() -> ()") 514 515 def test_name_must_match(self): 516 with self.assertRaisesRegex(ValueError, "to have name"): 517 518 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 519 def baz(x: Tensor) -> Tensor: 520 raise NotImplementedError 521 522 def test_unsupported_schemas(self): 523 with self.assertRaisesRegex(ValueError, "only supports functional"): 524 custom_ops.custom_op( 525 f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) x) -> Tensor(a)" 526 )(foo) 527 with self.assertRaisesRegex(ValueError, "only supports functional"): 528 custom_ops.custom_op( 529 f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)" 530 )(foo) 531 with self.assertRaisesRegex(ValueError, "only supports functional"): 532 custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")( 533 foo 534 ) 535 with self.assertRaisesRegex(ValueError, "self"): 536 custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor self) -> ()")( 537 foo 538 ) 539 540 # Tests for the older custom_op API 541 def test_schema_matches_signature(self): 542 with self.assertRaisesRegex(ValueError, "signature to match"): 543 544 @custom_op(f"{TestCustomOp.test_ns}::blah", "(Tensor y) -> Tensor") 545 def blah(x): 546 pass 547 548 with self.assertRaisesRegex(ValueError, "signature to match"): 549 550 @custom_op( 551 f"{TestCustomOp.test_ns}::blah2", "(Tensor x, *, Tensor y) -> Tensor" 552 ) 553 def blah2(x, y): 554 pass 555 556 with self.assertRaisesRegex(ValueError, "signature to match"): 557 558 @custom_op( 559 f"{TestCustomOp.test_ns}::blah3", 560 "(Tensor x, *, Tensor w, Tensor z) -> Tensor", 561 ) 562 def blah3(x, *, y, z): 563 pass 564 565 with self.assertRaisesRegex(ValueError, "signature to match"): 566 567 @custom_op( 568 f"{TestCustomOp.test_ns}::blah4", 569 "(Tensor x, *, Tensor z, Tensor y) -> Tensor", 570 ) 571 def blah4(x, *, y, z): 572 pass 573 574 with self.assertRaisesRegex(ValueError, "not supported"): 575 576 @custom_op(f"{TestCustomOp.test_ns}::blah5", "(Tensor x) -> Tensor") 577 def blah5(*args): 578 pass 579 580 with self.assertRaisesRegex(ValueError, "not supported"): 581 582 @custom_op( 583 f"{TestCustomOp.test_ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor" 584 ) 585 def blah6(**kwargs): 586 pass 587 588 with self.assertRaisesRegex(ValueError, 
"default arguments"): 589 590 @custom_op( 591 f"{TestCustomOp.test_ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor" 592 ) 593 def blah7(x=1, *, y): 594 pass 595 596 with self.assertRaisesRegex(ValueError, "default arguments"): 597 598 @custom_op( 599 f"{TestCustomOp.test_ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor" 600 ) 601 def blah8(x, *, y=1): 602 pass 603 604 # kwonly-arg works 605 @custom_op( 606 f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor" 607 ) 608 def blah9(x, *, y): 609 pass 610 611 def test_infer_schema_no_return(self): 612 with self.assertRaisesRegex( 613 ValueError, "No return type annotation was provided. Please add one." 614 ): 615 616 @torch.library.custom_op("mylib::foo", mutates_args={}) 617 def foo(x: torch.Tensor, y: int): 618 return x * y 619 620 def test_infer_schema_supported(self): 621 def a(x: Tensor) -> Tensor: 622 return torch.empty([]) 623 624 self.assertExpectedInline( 625 infer_schema(a, mutates_args=()), """(Tensor x) -> Tensor""" 626 ) 627 628 def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor: 629 return torch.empty([]) 630 631 self.assertExpectedInline( 632 infer_schema(kwonly1, mutates_args=()), 633 """(Tensor x, *, SymInt y, float z) -> Tensor""", 634 ) 635 636 def kwonly2(*, y: Tensor) -> Tensor: 637 return torch.empty([]) 638 639 self.assertExpectedInline( 640 infer_schema(kwonly2, mutates_args=()), """(*, Tensor y) -> Tensor""" 641 ) 642 643 def b( 644 x: Tensor, 645 y: int, 646 z: bool, 647 a: float, 648 b: torch.dtype, 649 c: torch.device, 650 d: torch.types.Number, 651 ) -> Tuple[Tensor, int, float, bool]: 652 return torch.empty([]), 1, 0.1, True 653 654 self.assertExpectedInline( 655 infer_schema(b, mutates_args=()), 656 """(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""", 657 ) 658 659 def c( 660 x: Tensor, 661 y: Sequence[Tensor], 662 z: Optional[Tensor], 663 w: Sequence[Optional[Tensor]], 664 ) -> List[Tensor]: 665 return [torch.empty([])] 666 667 self.assertExpectedInline( 668 infer_schema(c, mutates_args=()), 669 """(Tensor x, Tensor[] y, Tensor? z, Tensor?[] w) -> Tensor[]""", 670 ) 671 672 def d(x: Tensor) -> Tuple[List[Tensor], Tensor]: 673 return [torch.empty([])], torch.empty([]) 674 675 self.assertExpectedInline( 676 infer_schema(d, mutates_args=()), """(Tensor x) -> (Tensor[], Tensor)""" 677 ) 678 679 def e() -> Tensor: 680 return torch.empty([]) 681 682 self.assertExpectedInline(infer_schema(e, mutates_args=()), """() -> Tensor""") 683 684 def f(x: Tensor) -> None: 685 pass 686 687 self.assertExpectedInline( 688 infer_schema(f, mutates_args=()), """(Tensor x) -> ()""" 689 ) 690 691 def g( 692 x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]] 693 ) -> None: 694 pass 695 696 self.assertExpectedInline( 697 infer_schema(g, mutates_args=()), 698 """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""", 699 ) 700 701 self.assertExpectedInline( 702 infer_schema(g, mutates_args={"x", "w", "z"}), 703 """(Tensor(a0!) x, Tensor[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""", 704 ) 705 706 self.assertExpectedInline( 707 infer_schema(g, mutates_args="unknown"), 708 """(Tensor(a0!) 

        def h(
            x: Tensor,
            a: Optional[int] = None,
            b: float = 3.14,
            c: bool = True,
            d: int = 3,
            e: str = "foo",
            f: torch.dtype = torch.float,
            g: torch.dtype = torch.float32,
            h: torch.dtype = torch.int,
            i: torch.device = torch.device("cpu:0"),
            j: torch.device = "cpu",
        ) -> None:
            pass

        self.assertExpectedInline(
            infer_schema(h, mutates_args=()),
            (
                """(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
                """ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""
            ),
        )

        def foo_impl(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
        self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")

    def test_infer_schema_unsupported(self):
        with self.assertRaisesRegex(ValueError, "varargs"):

            def foo(*args):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "varkwargs"):

            def foo(**kwargs):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "must have a type annotation"):

            def foo(x):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "unsupported"):

            def foo(x: Tensor) -> Tuple[Tensor, ...]:
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "can be mutated"):

            def foo(x: Tensor, y: int) -> Tensor:
                raise NotImplementedError

            infer_schema(foo, mutates_args={"y"})

    def _generate_examples(self, typ):
        if typ is int:
            return [17]
        if typ is float:
            return [3.14]
        if typ is bool:
            return [True]
        if typ is str:
            return ["foo"]
        if typ is torch.dtype:
            return [torch.float32]
        if typ is torch.device:
            return [torch.device("cpu")]
        if typ == torch.types.Number:
            return [2.718]
        if typ is torch.Tensor:
            return [torch.tensor(3)]
        if typ == Optional[torch.types.Number]:
            return [None, 2.718]
        origin = typing.get_origin(typ)
        if origin is Union:
            args = typing.get_args(typ)
            assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None))
            elt = args[0] if args[1] is type(None) else args[1]
            return self._generate_examples(elt) + [None]
        if origin is list:
            args = typing.get_args(typ)
            assert len(args) == 1
            elt = args[0]
            return [
                self._generate_examples(elt),
                self._generate_examples(elt),
                self._generate_examples(elt),
            ]
        if origin is collections.abc.Sequence:
            args = typing.get_args(typ)
            assert len(args) == 1
            examples = self._generate_examples(args[0])
            return list(itertools.product(examples, examples)) + []
        raise NotImplementedError(
            f"testrunner cannot generate instance of type {typ}"
        )

    def test_supported_return_types_single_return(self):
        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):
                try:

                    @custom_ops.custom_op(f"{self.test_ns}::foo")
                    def foo(x: Tensor) -> typ:
                        raise NotImplementedError

                    @custom_ops.impl(f"{self.test_ns}::foo")
                    def foo_impl(x: Tensor) -> typ:
                        return example

                    op = self.get_op(f"{self.test_ns}::foo")
                    result = op(torch.randn([]))
                    self.assertEqual(result, example, msg=f"{typ} {example}")
                finally:
                    custom_ops._destroy(f"{self.test_ns}::foo")
self.get_op(f"{self.test_ns}::foo") 833 result = op(torch.randn([])) 834 self.assertEqual(result, example, msg=f"{typ} {example}") 835 finally: 836 custom_ops._destroy(f"{self.test_ns}::foo") 837 838 def test_supported_return_types_multi_return(self): 839 for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES: 840 for example in self._generate_examples(typ): 841 try: 842 843 @custom_ops.custom_op(f"{self.test_ns}::foo") 844 def foo(x: Tensor) -> Tuple[typ, typ]: 845 raise NotImplementedError 846 847 @custom_ops.impl(f"{self.test_ns}::foo") 848 def foo_impl(x: Tensor) -> Tuple[typ, typ]: 849 return (example, example) 850 851 op = self.get_op(f"{self.test_ns}::foo") 852 result = op(torch.randn([])) 853 expected = (example, example) 854 self.assertEqual(result, expected, msg=f"{typ} {example}") 855 finally: 856 custom_ops._destroy(f"{self.test_ns}::foo") 857 858 def test_supported_param_types(self): 859 for typ in torch._library.infer_schema.SUPPORTED_PARAM_TYPES: 860 861 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 862 def foo(x: Tensor, y: typ) -> Tensor: 863 raise NotImplementedError 864 865 yeet = None 866 867 @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=["cpu"]) 868 def foo_cpu(x, y): 869 nonlocal yeet 870 yeet = y 871 return x.clone() 872 873 try: 874 for example in self._generate_examples(typ): 875 op = self.get_op(f"{self.test_ns}::foo") 876 op(torch.randn([]), example) 877 self.assertEqual(yeet, example, msg=f"{typ} {example}") 878 yeet = None 879 finally: 880 custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") 881 882 def test_sequences(self): 883 # Sequence[int] gets automagically turned into int[] in the schema. 884 # This test checks that we actually do support arbitrary sequence types. 885 class MySequence(collections.abc.Sequence): 886 def __init__(self) -> None: 887 self._container = [1, 2, 3] 888 889 def __getitem__(self, idx): 890 return self._container[idx] 891 892 def __len__(self): 893 return len(self._container) 894 895 @custom_ops.custom_op(f"{self.test_ns}::foo") 896 def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor: 897 raise NotImplementedError 898 899 called = 0 900 901 @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu") 902 def foo_cpu(x, sizes): 903 nonlocal called 904 called += 1 905 # Dispatcher will normalize the sequence type into a List 906 self.assertEqual(sizes, [1, 2, 3]) 907 return x.clone() 908 909 x = torch.randn([]) 910 seq = MySequence() 911 op = self.get_op(f"{self.test_ns}::foo") 912 op(x, seq) 913 self.assertEqual(called, 1) 914 915 def test_unsupported_param_types(self): 916 # Not comprehensive (it doesn't need to be), just a check that our mechanism works 917 with self.assertRaisesRegex(ValueError, "unsupported type"): 918 919 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 920 def foo(x: Tensor, y: List[Optional[int]]) -> Tensor: 921 raise NotImplementedError 922 923 del foo 924 925 with self.assertRaisesRegex(ValueError, "unsupported type"): 926 # int[N] in Dispatcher is a bit wild, so we don't try to support it. 927 @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") 928 def foo(x: Tensor, y: Tuple[int, int]) -> Tensor: 929 raise NotImplementedError 930 931 del foo 932 933 with self.assertRaisesRegex(ValueError, r"For example, typing.List\[int\]"): 934 # test that we propose a correct and supported type. 
            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError

            del foo

        with self.assertRaises(ValueError) as cm:

            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
            def foo(x: Tensor, y: Tuple[int, float]) -> Tensor:
                raise NotImplementedError

            del foo

        self.assertNotIn("example", str(cm.exception), "")

        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Callable) -> Tensor:
                raise NotImplementedError

            del foo

    def test_supported_schemas(self):
        # All of these should already be tested by PyTorch codegen
        # (we share the same mechanism), but here's a sanity check.
        schemas = [
            "(Tensor x) -> Tensor",
            "(Tensor x) -> Tensor y",
            "(Tensor[] x) -> Tensor y",
            "(Tensor x) -> (Tensor, Tensor)",
            "(Tensor x) -> (Tensor y, Tensor z)",
            "(Tensor x) -> (Tensor y, Tensor z)",
        ]
        other_schemas = [
            "(Tensor x, Tensor w) -> (Tensor y, Tensor z)",
            "(Tensor x, Tensor w) -> (Tensor, Tensor)",
            "(Tensor x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor[] w) -> Tensor",
            "(Tensor x, int[] w) -> Tensor",
            "(Tensor x, SymInt[] w) -> Tensor",
            "(Tensor x, Scalar w) -> Tensor",
            "(Tensor x, float w) -> Tensor",
            "(Tensor x, float? w) -> Tensor",
            "(Tensor x, bool[] w) -> Tensor",
        ]

        for schema in schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        for schema in other_schemas:
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema)
            custom_ops._destroy(f"{TestCustomOp.test_ns}::bar")

    def test_reserved_ns(self):
        from torch._custom_op.impl import RESERVED_NS

        for ns in RESERVED_NS:
            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):
                custom_ops.custom_op(f"{ns}::foo", "(Tensor x) -> Tensor")

            with self.assertRaisesRegex(ValueError, "is a reserved namespace"):

                @custom_ops.custom_op(f"{ns}::foo2")
                def foo2(x: torch.Tensor) -> torch.Tensor:
                    raise NotImplementedError

    def test_private_ctor(self):
        with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"):
            CustomOp(None, None, None, None, None)

    def test_lifetime(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo")

        # We can't define an op multiple times,
        with self.assertRaisesRegex(RuntimeError, "multiple times"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
                raise NotImplementedError

        # Unless we delete the original op.
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

        # Smoke test
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError

        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_autograd_notimplemented(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:  # noqa: F811
            raise NotImplementedError

        x = torch.randn(3, requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op(x)
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        del foo

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op([y, x])
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")
        del foo

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"):
            op(y, x)
        custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_autograd_notimplemented_gradmode(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x * y

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with torch.no_grad():
            # Shouldn't raise, because we are in no_grad
            op(y, x)

    def test_impl_cpu(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x)
        self.assertEqual(result, foo_cpu(x))

    def test_impl_invalid_devices(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        def foo_impl(x):
            return x.sin()

        from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY

        for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys():
            # Smoke test: should not raise error
            custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)(
                foo_impl
            )

        # Not supported by this API: we can either support them in the future
        # or provide some other CustomOp.def_* function. This depends on how
        # common the use cases are.
        for invalid_type in ["hip", "xla", "mkldnn", ["cpu", "hip"]]:
            with self.assertRaisesRegex(ValueError, "we only support device_type"):
                custom_ops.impl(
                    f"{TestCustomOp.test_ns}::foo", device_types=invalid_type
                )(foo_impl)

    def test_backward_partially_registered(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(
            RuntimeError, "unable to find a 'save_for_backward'"
        ):
            y = op(x)
            y.backward()

    def test_save_for_backward_inputs_are_namedtuple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        hit = 0

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            nonlocal hit
            hit += 1
            self.assertTrue(isinstance(inputs, tuple))
            self.assertEqual(list(inputs._asdict().keys()), ["x"])
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        self.assertEqual(hit, 1)
        y.backward()
        self.assertEqual(hit, 1)

    def test_backward_returns_dict(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "to be a dict"):
            y.backward()

    def test_backward_dict_invalid_keys(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "y": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
            y.backward()
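    # The next few tests pin down the impl_backward contract: gradients come
    # back as a dict keyed by input name, non-Tensor inputs must not receive
    # gradients, and every Tensor input needs an entry.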
    def test_backward_dict_grad_for_nontensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, dim):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "dim": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, 32)
        with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
            y.backward()

    def test_backward_dict_requires_keys_for_input_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, x)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            y.backward()

    def test_backward_dict_requires_keys_for_input_optional_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, None)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            y.backward()

    def test_backward_grads_are_tensor_or_none(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": (grad * saved.cos(),)}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"):
            y.backward()

    def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

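        # Reminder of the contract (see
        # test_save_for_backward_inputs_are_namedtuple above): `inputs` is a
        # namedtuple of the op's arguments, and whatever this returns is
        # handed to the backward as `saved`.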
        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(xs)
        with self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"):
            y.backward()

    def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None, (None,)]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(xs)
        with self.assertRaisesRegex(RuntimeError, "None or Tensor"):
            y.backward()

    def test_backward_tensorlist_input_requires_list_grads(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": None}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(xs)
        with self.assertRaisesRegex(RuntimeError, "list of gradients"):
            y.backward()

    def test_backward_output_differentiability_type(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(
                f"{TestCustomOp.test_ns}::foo", output_differentiability=True
            )
            def foo_backward(ctx, saved, grad):
                return {"xs": None}

    def test_backward_output_differentiability_numel(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            raise NotImplementedError

        with self.assertRaisesRegex(RuntimeError, "output_differentiability"):

            @custom_ops.impl_backward(
                f"{TestCustomOp.test_ns}::foo", output_differentiability=[True]
            )
            def foo_backward(ctx, saved, grad):
                return {"xs": None}

    def test_backward_output_differentiability_tensorlist(self):
        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]:
            raise NotImplementedError

        @custom_ops.impl(f"{self.test_ns}::foo")
        def foo_impl(x):
            return [x.clone(), x.clone()], x.clone()

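        # output_differentiability takes one bool per output; outputs marked
        # False are treated as non-differentiable and should come back with
        # requires_grad=False (asserted below).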
        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return []

        @custom_ops.impl_backward(
            f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True]
        )
        def foo_backward(ctx, saved, grad_lst, grad):
            return {"x": grad}

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3, requires_grad=True)
        [a, b], c = op(x)
        self.assertFalse(a.requires_grad)
        self.assertFalse(b.requires_grad)
        self.assertTrue(c.requires_grad)

    def test_backward_output_differentiability_non_tensor(self):
        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: Tensor) -> Tuple[Tensor, int]:
            raise NotImplementedError

        @custom_ops.impl(f"{self.test_ns}::foo")
        def foo_impl(x):
            return x.clone(), 3

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return []

        @custom_ops.impl_backward(
            f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True]
        )
        def foo_backward(ctx, saved, grad0, grad1):
            return {"x": grad0}

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "is not a Tensor"):
            op(x)

    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_separate(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu")
        def foo_cpu(x):
            return x.sin()

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda")
        def foo_cuda(x):
            return x.cos()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x)
        self.assertEqual(result, foo_cpu(x))

        x_cuda = x.cuda()
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x_cuda)
        self.assertEqual(result, foo_cuda(x_cuda))

    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_impl_multiple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.cos()

        op = self.get_op(f"{self.test_ns}::foo")
        x = torch.randn(3)
        result = op(x)
        self.assertEqual(result, foo_impl(x))

        x_cuda = x.cuda()
        result = op(x_cuda)
        self.assertEqual(result, foo_impl(x_cuda))

    def test_impl_abstract_overload(self):
        lib = self.lib()
        lib.define("sin.blah(Tensor x) -> Tensor")

        torch.library.impl_abstract(
            f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib
        )

        op = self.ns().sin.blah
        x = torch.randn(3, device="meta")
        op(x)

    def test_impl_meta(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        x = torch.randn(2, 3, device="meta")
        op = self.get_op(f"{self.test_ns}::foo")
        result = op(x, 1)
        self.assertEqual(result.shape, foo_meta(x, 1).shape)

    def test_duplicate_impl(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x, dim):
            output_shape = list(x.shape)
            del output_shape[dim]
            return x.new_empty(output_shape)

        with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"):

            @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
            def foo_meta2(x, dim):
                output_shape = list(x.shape)
                del output_shape[dim]
                return x.new_empty(output_shape)

    def test_new_data_dependent_symint(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x):
            ctx = torch.library.get_ctx()
            r = ctx.new_dynamic_size(min=1)
            with self.assertRaisesRegex(ValueError, "greater than or equal to 0"):
                ctx.new_dynamic_size(min=-1)
            with self.assertRaisesRegex(ValueError, "SymInt"):
                ctx.new_dynamic_size(max=x.numel())
            # NB: You must return dynamic sizes!
            return x.new_empty(r)

        x = torch.randn(2, 3, device="cpu")
        op = self.get_op(f"{self.test_ns}::foo")
        make_fx(op, tracing_mode="symbolic")(x)

    def test_meta_for_data_dependent_shape_operation(self):
        x = torch.randn(10, device="meta")
        with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):
            numpy_nonzero(x)

    def test_basic_make_fx(self):
        # More serious tests are in our CustomOp opinfo db,
        # this one is just a sanity check.
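        # (make_fx with tracing_mode="symbolic" needs a fake/meta kernel for
        # the op; the impl_abstract registration below provides it.)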
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x):
            return x.sum()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        gm = make_fx(op, tracing_mode="symbolic")(x)
        self.assertTrue(f"{TestCustomOp.test_ns}.foo" in gm.code)

    def test_not_implemented_error(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):
            op(x)

        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(NotImplementedError, "no fake impl or Meta kernel"):
            op(x)

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar")
        def bar(sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError

        op = self.get_op(f"{self.test_ns}::bar")
        with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
            op((1, 2, 3))

    def test_data_dependent_basic(self):
        x = torch.randn(5, 5)
        gm = make_fx(numpy_nonzero, tracing_mode="symbolic")(x)
        self.assertTrue("nonzero" in gm.code)

    def test_data_dependent_fake_tracing(self):
        x = torch.randn(5, 5)
        # We've updated to attempt to use unbacked symints even for fake
        # tracing
        make_fx(numpy_nonzero, tracing_mode="fake")(x)

    def test_symints(self):
        def f(x):
            return torch.ops._torch_testing.numpy_view_copy(x, x.shape)

        x = torch.randn(2, 3, 4)
        gm = make_fx(f, tracing_mode="symbolic")(x)
        result = gm(x)
        self.assertEqual(result, f(x))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
    sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2)
    numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None
    return numpy_view_copy""",  # noqa: B950
        )

    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    def test_data_dependent_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            return numpy_nonzero(x.clone()).clone()

        f(torch.randn(10))

        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(next(iter(counters["graph_break"].values())), 1)
        self.assertExpectedInline(
            next(iter(counters["graph_break"].keys())).replace(";", "\n"),
            """\
dynamic shape operator: _torch_testing.numpy_nonzero.default
 to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""",
        )

    # pre-existing problem: torch.compile(dynamic=True) will, by default,
    # graph break on data-dependent operations. Eventually we'll make it so
    # that it never graph breaks on data-dependent operations.
    @unittest.expectedFailure
    def test_data_dependent_nms_dynamic_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, dynamic=True)
        def f(x, s, i):
            return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone()

        f(torch.randn(20, 4), torch.randn(20), 0.1)

        self.assertEqual(len(counters["graph_break"]), 0)

    def test_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch._custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        x = torch.randn(3)
        result = op(x)
        self.assertEqual(result, x.sin())

    @parametrize(
        "key", ["CPU", "CUDA", "CompositeImplicitAutograd", "CompositeExplicitAutograd"]
    )
    def test_impl_on_existing_op_with_cpu_registration(self, key):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, key)
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "already has an implementation"):
            custom_ops.impl(qualname, func=foo_impl)

    def test_abstract_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        with torch._subclasses.FakeTensorMode():
            x = torch.randn(3)
            result = op(x)
            self.assertEqual(result.shape, x.shape)
            self.assertEqual(result.stride(), x.stride())

    def test_abstract_impl_on_existing_op_with_meta(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "Meta")
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
        op = self.get_op(qualname)

        torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
        with torch._subclasses.FakeTensorMode():
            x = torch.randn(10)
            result = op(x)
            self.assertEqual(result.shape, ())

    def _test_backward_impl_raises(self, qualname, err_regex):
        with self.assertRaisesRegex(RuntimeError, err_regex):

            @custom_ops.impl_save_for_backward(qualname)
            def foo2(x):
                return

        with self.assertRaisesRegex(RuntimeError, err_regex):

            @custom_ops.impl_backward(qualname)
            def foo3(x):
                return

    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
        lib = self.lib()
        lib.define("foo(Tensor(a) x) -> Tensor(a)")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "operator that returns views")

    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "non-functional")

    def test_backward_impl_on_existing_op_incorrect_schema_no_output(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> ()")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "no returns")

    def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd")
        self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd")

    @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"])
    def test_backward_impl_on_existing_op_with_key(self, key):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"
        lib.impl("foo", lambda x: x.sin().cos(), key)
        self._test_backward_impl_raises(qualname, key)

    def test_is_functional_schema(self):
        tests = {
            "foo(Tensor x) -> Tensor": True,
            "foo(Tensor(a) x) -> Tensor": True,
            "foo(Tensor(a!) x) -> Tensor": False,
            "foo(Tensor(a) x) -> Tensor(a)": False,
            "foo(Tensor x) -> ()": False,
        }
        for schema_str, expected in tests.items():
            res = torch._library.utils.is_functional_schema(schema_str)
            self.assertEqual(res, expected)

            from torchgen.model import FunctionSchema

            schema = FunctionSchema.parse(schema_str)
            res = torch._library.utils.is_functional_schema(schema)
            self.assertEqual(res, expected)

            schema = torch._C.parse_schema(schema_str)
            res = torch._library.utils.is_functional_schema(schema)
            self.assertEqual(res, expected)

    def test_incorrect_schema_types(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
                lib.define("foo12(Tensor a) -> asdfasdf")
            with self.assertRaisesRegex(RuntimeError, "unknown type specifier"):
                lib.define("foo12(asdf a) -> Tensor")
            with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"):
                lib.define("foo12(int64_t a) -> Tensor")
            with self.assertRaisesRegex(RuntimeError, "Use `float`"):
                lib.define("foo12(double a) -> Tensor")

    def test_is_tensorlist_like_type(self):
        tensorlists = [
            # Tensor[]
            torch.ops.aten.where.default._schema.returns[0].type,
            # Tensor?[]
            torch.ops.aten.index.Tensor._schema.arguments[1].type,
            # Tensor[]?
            torch._C.parse_schema("foo(Tensor[]? x) -> ()").arguments[0].type,
            # Tensor?[]?
            torch._C.parse_schema("foo(Tensor?[]? x) -> ()").arguments[0].type,
        ]
x) -> ()").arguments[0].type, 1862 ] 1863 non_tensorlists = [ 1864 # Tensor 1865 torch.ops.aten.sin.default._schema.arguments[0].type, 1866 # IntList 1867 torch.ops.aten.sum.dim_IntList._schema.arguments[1].type, 1868 ] 1869 for a in tensorlists: 1870 self.assertTrue(torch._library.utils.is_tensorlist_like_type(a)) 1871 for a in non_tensorlists: 1872 self.assertFalse(torch._library.utils.is_tensorlist_like_type(a)) 1873 1874 def test_backward_impl_on_existing_op(self): 1875 lib = self.lib() 1876 lib.define("foo(Tensor x) -> Tensor") 1877 qualname = f"{self.test_ns}::foo" 1878 1879 @custom_ops.impl(qualname) 1880 def foo_impl(x): 1881 with torch.no_grad(): 1882 return x.sin() 1883 1884 @custom_ops.impl_save_for_backward(qualname) 1885 def foo_save_for_backward(inputs, output): 1886 return inputs.x 1887 1888 @custom_ops.impl_backward(qualname) 1889 def foo_backward(ctx, saved, grad_out): 1890 return {"x": grad_out * saved.cos()} 1891 1892 op = self.get_op(qualname) 1893 x = torch.randn([], requires_grad=True) 1894 y = op(x) 1895 (gx,) = torch.autograd.grad(y, x) 1896 self.assertEqual(gx, x.cos()) 1897 1898 @parametrize( 1899 "tags", 1900 [ 1901 subtest(torch.Tag.pointwise, "single"), 1902 subtest((torch.Tag.pointwise,), "tuple"), 1903 subtest([torch.Tag.pointwise], "list"), 1904 ], 1905 ) 1906 def test_define_with_tags(self, tags): 1907 lib = self.lib() 1908 tags = (torch.Tag.pointwise,) 1909 torch.library.define( 1910 f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags 1911 ) 1912 actual = self.ns().foo.default.tags 1913 self.assertTrue(isinstance(actual, list)) 1914 self.assertEqual(actual, list(tags)) 1915 1916 def test_builtin_aten_ops_are_pt2_compliant(self): 1917 for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]: 1918 self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) 1919 1920 def test_builtin_torchscript_ops(self): 1921 for op in [torch.ops.aten.sub.complex, torch.ops.aten.mul.complex]: 1922 self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) 1923 1924 def test_autogen_aten_ops_are_pt2_compliant(self): 1925 for op in [torch.ops.aten.fill.Tensor_out]: 1926 self.assertIn(torch.Tag.generated, op.tags) 1927 self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) 1928 1929 def test_resolve_packet(self): 1930 x = torch.randn(3) 1931 result = torch._C._jit_resolve_packet("aten::sum", x) 1932 self.assertEqual(result, "default") 1933 1934 result = torch._C._jit_resolve_packet("aten::sum", x, dim=1) 1935 self.assertEqual(result, "dim_IntList") 1936 1937 with self.assertRaisesRegex(RuntimeError, "failed to match any schema"): 1938 result = torch._C._jit_resolve_packet("aten::sum", x, x, x) 1939 1940 def test_define_bad_schema(self): 1941 lib = self.lib() 1942 with self.assertRaisesRegex(ValueError, "expected schema to look like"): 1943 torch.library.define(f"{self.test_ns}::foo", "foo(Tensor x) -> Tensor") 1944 1945 def test_define_and_impl(self): 1946 lib = self.lib() 1947 torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) 1948 1949 @torch.library.impl(f"{self.test_ns}::foo", "CPU", lib=lib) 1950 def f(x): 1951 return torch.from_numpy(np.sin(x.numpy())) 1952 1953 x = torch.randn(3) 1954 y = self.ns().foo(x) 1955 assert torch.allclose(y, x.sin()) 1956 1957 def test_define_validation(self): 1958 with self.assertRaisesRegex(ValueError, "namespace"): 1959 torch.library.define("foo", "(Tensor x) -> Tensor") 1960 1961 def test_legacy_define(self): 1962 lib = self.lib() 1963 1964 @torch.library.define(lib, "foo(Tensor x) -> Tensor") 1965 def 
f(x): 1966 return torch.from_numpy(np.sin(x.numpy())) 1967 1968 x = torch.randn(3) 1969 y = self.ns().foo(x) 1970 assert torch.allclose(y, x.sin()) 1971 1972 def test_impl_function(self): 1973 lib = self.lib() 1974 torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) 1975 1976 def f(x): 1977 return torch.from_numpy(np.sin(x.numpy())) 1978 1979 torch.library.impl(f"{self.test_ns}::foo", "CPU", f, lib=lib) 1980 x = torch.randn(3) 1981 y = self.ns().foo(x) 1982 assert torch.allclose(y, x.sin()) 1983 1984 def test_legacy_impl(self): 1985 lib = self.lib() 1986 torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) 1987 1988 @torch.library.impl(lib, "foo", "CPU") 1989 def f(x): 1990 return torch.from_numpy(np.sin(x.numpy())) 1991 1992 x = torch.randn(3) 1993 y = self.ns().foo(x) 1994 assert torch.allclose(y, x.sin()) 1995 1996 def test_defined_in_python(self): 1997 self.assertFalse(torch.ops.aten.sin.default._defined_in_python) 1998 self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python) 1999 2000 lib = self.lib() 2001 torch.library.define("{self._test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) 2002 ns = self.ns() 2003 self.assertTrue(ns.foo.default._defined_in_python) 2004 2005 torch.library.define( 2006 "{self._test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib 2007 ) 2008 self.assertTrue(ns.bar.overload._defined_in_python) 2009 2010 def _test_impl_device(self, name, types, device): 2011 lib = self.lib() 2012 torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib) 2013 2014 @torch.library.impl(f"{self.test_ns}::{name}", types) 2015 def f(x): 2016 x_np = x.cpu().numpy() 2017 y = torch.from_numpy(np.sin(x_np)) 2018 return y.to(device=x.device) 2019 2020 x = torch.randn(3, device=device) 2021 y = getattr(self.ns(), name)(x) 2022 assert torch.allclose(y, x.sin()) 2023 2024 def test_impl_device_cpu(self): 2025 self._test_impl_device("foo1", "default", "cpu") 2026 self._test_impl_device("foo2", ["cpu"], "cpu") 2027 self._test_impl_device("foo3", ["cpu", "cuda"], "cpu") 2028 2029 @unittest.skipIf(not TEST_CUDA, "requires cuda") 2030 def test_impl_device_cuda(self): 2031 self._test_impl_device("foo4", "default", "cuda") 2032 self._test_impl_device("foo5", ["cuda"], "cuda") 2033 self._test_impl_device("foo6", ["cpu", "cuda"], "cuda") 2034 2035 def test_impl_device_function(self): 2036 lib = self.lib() 2037 torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) 2038 2039 def f(x): 2040 x_np = x.cpu().numpy() 2041 y = torch.from_numpy(np.sin(x_np)) 2042 return y.to(device=x.device) 2043 2044 torch.library.impl(f"{self.test_ns}::foo", "default", f, lib=lib) 2045 x = torch.randn(3) 2046 y = self.ns().foo(x) 2047 assert torch.allclose(y, x.sin()) 2048 2049 def test_impl_device_invalid(self): 2050 with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"): 2051 torch.library.impl("blah::blah", "somethingsomething") 2052 2053 def test_autograd_function_backed_op(self): 2054 cpp_source = """ 2055struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 2056 static constexpr bool is_traceable = true; 2057 2058 static torch::Tensor forward( 2059 torch::autograd::AutogradContext* ctx, 2060 const torch::Tensor& x) { 2061 return x; 2062 } 2063 2064 static torch::autograd::variable_list backward( 2065 torch::autograd::AutogradContext *ctx, 2066 torch::autograd::variable_list grad_output) { 2067 return grad_output; 2068 } 2069}; 2070 2071torch::Tensor 
custom_op_backed_by_autograd_fn(const torch::Tensor& x) { 2072 return CustomOpAutogradFunction::apply(x); 2073} 2074 2075TORCH_LIBRARY(mylib, m) { 2076 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 2077} 2078 """ 2079 2080 module = torch.utils.cpp_extension.load_inline( 2081 name="mylib", 2082 cpp_sources=cpp_source, 2083 functions="custom_op_backed_by_autograd_fn", 2084 verbose=True, 2085 ) 2086 2087 x = torch.ones(2, 2, requires_grad=True) 2088 temp = x.clone().detach() 2089 out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x) 2090 loss = out.sum() 2091 loss.backward() 2092 self.assertEqual(x.grad, temp) 2093 2094 2095def op_with_incorrect_schema(testcase, name): 2096 lib = testcase.lib() 2097 lib.define(f"{name}(Tensor x) -> Tensor") 2098 qualname = f"{testcase.test_ns}::{name}" 2099 lib.impl(name, lambda x: x[:], "CompositeExplicitAutograd") 2100 return testcase.get_op(qualname) 2101 2102 2103class MiniOpTest(CustomOpTestCaseBase): 2104 test_ns = "mini_op_test" 2105 2106 def _init_op_delayed_backward_error(self): 2107 name = "delayed_error" 2108 qualname = f"{self.test_ns}::{name}" 2109 lib = self.lib() 2110 lib.define(f"{name}(Tensor x) -> Tensor") 2111 lib.impl(name, lambda x: x.clone(), "CompositeExplicitAutograd") 2112 op = self.get_op(qualname) 2113 2114 class Op(torch.autograd.Function): 2115 @staticmethod 2116 def forward(ctx, x): 2117 with torch._C._AutoDispatchBelowAutograd(): 2118 return op(x) 2119 2120 @staticmethod 2121 def backward(ctx, grad): 2122 raise NotImplementedError 2123 2124 def autograd_impl(x): 2125 return Op.apply(x) 2126 2127 lib.impl(name, autograd_impl, "Autograd") 2128 return op 2129 2130 def _init_op_with_no_abstract_impl(self): 2131 name = "no_abstract" 2132 qualname = f"{self.test_ns}::{name}" 2133 lib = self.lib() 2134 lib.define(f"{name}(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,)) 2135 lib.impl(name, lambda x: x.clone(), "CPU") 2136 return torch._library.utils.lookup_op(qualname) 2137 2138 def setUp(self): 2139 super().setUp() 2140 self._op_with_no_abstract_impl = self._init_op_with_no_abstract_impl() 2141 self._op_delayed_backward_error = self._init_op_delayed_backward_error() 2142 2143 @optests.dontGenerateOpCheckTests("Testing this API") 2144 def test_dont_generate(self): 2145 op = op_with_incorrect_schema(self, "incorrect_schema") 2146 x = torch.randn(3) 2147 op(x) 2148 2149 def test_mm(self): 2150 x = torch.randn(2, 3, requires_grad=True) 2151 y = torch.randn(3, 5) 2152 result = torch.ops.aten.mm.default(x, y) 2153 self.assertEqual(result, x @ y) 2154 2155 def test_mm_meta(self): 2156 x = torch.randn(2, 3, requires_grad=True, device="meta") 2157 y = torch.randn(3, 5, device="meta") 2158 result = torch.ops.aten.mm.default(x, y) 2159 self.assertEqual(result.shape, (x @ y).shape) 2160 2161 def test_mm_fake(self): 2162 with torch._subclasses.fake_tensor.FakeTensorMode(): 2163 x = torch.randn(2, 3, requires_grad=True, device="cpu") 2164 y = torch.randn(3, 5, device="cpu") 2165 result = torch.ops.aten.mm.default(x, y) 2166 self.assertEqual(result.shape, (x @ y).shape) 2167 2168 def test_mm_errors(self): 2169 x = torch.randn(2, 3, requires_grad=True) 2170 y = torch.randn(4, 5) 2171 with self.assertRaisesRegex(RuntimeError, "cannot be multiplied"): 2172 result = torch.ops.aten.mm.default(x, y) 2173 2174 def test_nonzero(self): 2175 x = torch.tensor([0, 1, 2, 0, 0]) 2176 y = torch.ops.aten.nonzero.default(x) 2177 self.assertEqual(y, torch.tensor([[1], [2]])) 2178 2179 def test_inplace(self): 2180 x = 
torch.randn(3) 2181 x_clone = x.clone() 2182 y = torch.ops.aten.sin_(x) 2183 self.assertEqual(x, x_clone.sin()) 2184 2185 def test_incorrect_schema(self): 2186 op = op_with_incorrect_schema(self, "incorrect_schema") 2187 x = torch.randn(3) 2188 op(x) 2189 2190 def test_no_abstract(self): 2191 op = self._op_with_no_abstract_impl 2192 x = torch.randn(3) 2193 op(x) 2194 2195 def test_delayed_error(self): 2196 op = self._op_delayed_backward_error 2197 x = torch.randn([], requires_grad=True) 2198 y = op(x) 2199 with self.assertRaises(NotImplementedError): 2200 y.sum().backward() 2201 2202 def test_delayed_error_no_requires_grad(self): 2203 op = self._op_delayed_backward_error 2204 x = torch.randn([]) 2205 y = op(x) 2206 2207 2208class TestCustomOpAPI(TestCase): 2209 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2210 def test_basic(self): 2211 @torch.library.custom_op("_torch_testing::add", mutates_args=()) 2212 def add(x: Tensor, y: float) -> Tensor: 2213 x_np = x.numpy(force=True) 2214 out_np = x_np + y 2215 return torch.from_numpy(out_np).to(x.device) 2216 2217 x = torch.randn(3) 2218 y = 3.14 2219 z = add(x, y) 2220 self.assertEqual(z, x + y) 2221 2222 cpu_called = False 2223 2224 @add.register_kernel("cpu") 2225 def _(x, y): 2226 nonlocal cpu_called 2227 cpu_called = True 2228 x_np = x.numpy() 2229 out_np = x_np + y 2230 return torch.from_numpy(out_np) 2231 2232 z = add(x, y) 2233 self.assertEqual(z, x + y) 2234 self.assertTrue(cpu_called) 2235 2236 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2237 def test_no_grad_skips_autograd(self): 2238 @torch.library.custom_op("_torch_testing::add", mutates_args=()) 2239 def add(x: Tensor, y: float) -> Tensor: 2240 x_np = x.numpy(force=True) 2241 out_np = x_np + y 2242 return torch.from_numpy(out_np).to(x.device) 2243 2244 called = 0 2245 2246 def setup_context(ctx, inputs, output): 2247 nonlocal called 2248 called += 1 2249 2250 def backward(ctx, grad): 2251 raise AssertionError("should not be reached") 2252 2253 add.register_autograd(backward, setup_context=setup_context) 2254 2255 x = torch.randn(3, requires_grad=True) 2256 with torch.no_grad(): 2257 y = add(x, 2.0) 2258 self.assertEqual(called, 0) 2259 self.assertEqual(y, x + 2.0) 2260 2261 x.requires_grad_(False) 2262 y = add(x, 2.0) 2263 self.assertEqual(called, 0) 2264 self.assertEqual(y, x + 2.0) 2265 2266 x = torch.randn(3, requires_grad=True) 2267 y = add(x, 2.0) 2268 self.assertEqual(called, 1) 2269 self.assertEqual(y, x + 2.0) 2270 2271 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2272 def test_manual_schema(self): 2273 @torch.library.custom_op( 2274 "_torch_testing::add", 2275 mutates_args=(), 2276 schema="(Tensor x, float y) -> Tensor", 2277 ) 2278 def add(x, y): 2279 x_np = x.numpy(force=True) 2280 out_np = x_np + y 2281 return torch.from_numpy(out_np).to(x.device) 2282 2283 x = torch.randn(3) 2284 y = 3.14 2285 z = add(x, y) 2286 self.assertEqual(z, x + y) 2287 2288 @torch.library.custom_op( 2289 "_torch_testing::sin_", 2290 mutates_args=["x"], 2291 schema="(Tensor(a!) 
x) -> ()", 2292 ) 2293 def sin_(x): 2294 x_np = x.numpy() 2295 np.sin(x_np, out=x_np) 2296 2297 x = torch.randn(3) 2298 expected = x.sin() 2299 sin_(x) 2300 self.assertEqual(x, expected) 2301 2302 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2303 def test_kwarg_only_tensors(self): 2304 with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): 2305 2306 @torch.library.custom_op("_torch_testing::foo", mutates_args=()) 2307 def foo(x: Tensor, *, y: int, z: Tensor) -> Tensor: 2308 pass 2309 2310 with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): 2311 2312 @torch.library.custom_op("_torch_testing::foo", mutates_args=()) 2313 def foo2(x: Tensor, *, y: int, z: Optional[Tensor]) -> Tensor: 2314 pass 2315 2316 with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): 2317 2318 @torch.library.custom_op("_torch_testing::foo", mutates_args=()) 2319 def foo3(x: Tensor, *, y: int, z: List[Tensor]) -> Tensor: 2320 pass 2321 2322 with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: 2323 lib.define("foo(Tensor x, *, Tensor y) -> Tensor") 2324 with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): 2325 torch.library.register_autograd( 2326 "_torch_testing::foo", 2327 lambda grad: grad, 2328 setup_context=lambda ctx, inputs, keyword_only_inputs, output: None, 2329 ) 2330 2331 with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): 2332 torch.library.register_vmap( 2333 "_torch_testing::foo", 2334 lambda info, in_dims, x, *, y: (x, 0), 2335 ) 2336 2337 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2338 def test_register_autograd_kwargonly_low_level(self): 2339 with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: 2340 lib.define("foo(Tensor x, *, float y) -> Tensor") 2341 called = False 2342 2343 def foo_impl(x, *, y): 2344 return x * y 2345 2346 lib.impl("foo", foo_impl, "CPU") 2347 2348 def backward(ctx, grad): 2349 nonlocal called 2350 called = True 2351 return grad * ctx.y 2352 2353 def setup_context(ctx, inputs, keyword_only_inputs, output): 2354 assert tuple(keyword_only_inputs.keys()) == ("y",) 2355 ctx.y = keyword_only_inputs["y"] 2356 2357 torch.library.register_autograd( 2358 "_torch_testing::foo", backward, setup_context=setup_context, lib=lib 2359 ) 2360 2361 x = torch.randn(3, requires_grad=True) 2362 torch.ops._torch_testing.foo(x, y=3.14).sum().backward() 2363 self.assertTrue(called) 2364 self.assertEqual(x.grad, torch.tensor([3.14, 3.14, 3.14])) 2365 2366 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2367 def test_register_autograd_defaults(self): 2368 with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: 2369 lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor") 2370 2371 def foo_impl(w, x=2, *, y=3, z): 2372 return w * x * y * z 2373 2374 lib.impl("foo", foo_impl, "CPU") 2375 2376 called = False 2377 2378 def backward(ctx, grad): 2379 nonlocal called 2380 called = True 2381 return grad * ctx.c 2382 2383 def setup_context(ctx, inputs, keyword_only_inputs, output): 2384 assert len(inputs) == 2 2385 assert inputs[1] == 2 2386 assert keyword_only_inputs == {"y": 3, "z": 42} 2387 ctx.c = keyword_only_inputs["y"] * keyword_only_inputs["z"] * inputs[1] 2388 2389 torch.library.register_autograd( 2390 "_torch_testing::foo", backward, setup_context=setup_context, lib=lib 2391 ) 2392 2393 w = torch.randn(3, requires_grad=True) 2394 
torch.ops._torch_testing.foo(w, z=42).sum().backward() 2395 self.assertTrue(called) 2396 self.assertEqual(w.grad, torch.full_like(w, 2 * 3 * 42)) 2397 2398 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2399 def test_manual_schema_error(self): 2400 with self.assertRaisesRegex(ValueError, "the op mutates {'x'}"): 2401 2402 @torch.library.custom_op( 2403 "_torch_testing::sin_", 2404 mutates_args=(), 2405 schema="(Tensor(a!) x) -> ()", 2406 ) 2407 def sin_(x): 2408 x_np = x.numpy() 2409 np.sin(x_np, out=x_np) 2410 2411 def test_supports_tensorlist(self): 2412 @torch._library.autograd.supports_tensorlist 2413 class Stack(torch.autograd.Function): 2414 @staticmethod 2415 def forward(ctx, xs): 2416 ctx.num_xs = len(xs) 2417 return torch.stack(xs) 2418 2419 @staticmethod 2420 def backward(ctx, grad): 2421 expected = ([True] * ctx.num_xs,) 2422 self.assertEqual(ctx.needs_input_grad, expected) 2423 return list(grad.unbind(0)) 2424 2425 # call two applys, do a backward on the first 2426 def t(): 2427 return torch.randn([], requires_grad=True) 2428 2429 xs0 = [t(), t(), t()] 2430 xs1 = [t(), t(), t(), t()] 2431 y0 = Stack.apply(xs0) 2432 y1 = Stack.apply(xs1) 2433 grads = torch.autograd.grad(y0.sum(), xs0) 2434 self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)]) 2435 2436 # call one apply, do multiple backwards 2437 xs = [t(), t(), t()] 2438 y = Stack.apply(xs) 2439 _ = torch.autograd.grad(y.sum(), xs, retain_graph=True) 2440 _ = torch.autograd.grad(y.sum(), xs, retain_graph=True) 2441 grads = torch.autograd.grad(y.sum(), xs, retain_graph=True) 2442 self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)]) 2443 2444 # error: on access forward, backward directly 2445 with self.assertRaisesRegex(NotImplementedError, "Function.forward directly"): 2446 Stack.forward(None, xs) 2447 with self.assertRaisesRegex(NotImplementedError, "Function.backward directly"): 2448 Stack.backward(None, xs) 2449 2450 # the recursive case 2451 @torch._library.autograd.supports_tensorlist 2452 class Foo(torch.autograd.Function): 2453 @staticmethod 2454 def forward(ctx, xs): 2455 if len(xs) > 1: 2456 return Foo.apply(xs[1:]) 2457 ctx.len_xs = len(xs) 2458 return xs[0].sin() 2459 2460 @staticmethod 2461 def backward(ctx, grad): 2462 result = [None] * ctx.len_xs 2463 result[-1] = grad.cos() 2464 return result 2465 2466 # should work 2467 result = Foo.apply(xs) 2468 expected = xs[-1].sin() 2469 self.assertEqual(result, expected) 2470 2471 # recursive on backward 2472 @torch._library.autograd.supports_tensorlist 2473 class Bar(torch.autograd.Function): 2474 @staticmethod 2475 def forward(ctx, xs): 2476 return [xs[i] + i for i in range(len(xs))] 2477 2478 @staticmethod 2479 def backward(ctx, grads): 2480 f1 = Bar.apply(grads[:2]) 2481 f2 = Bar.apply(grads[2:]) 2482 return f1 + f2 2483 2484 xs = [torch.tensor(0.0, requires_grad=True) for _ in range(5)] 2485 ys = Bar.apply(xs) 2486 sum(ys).backward() 2487 result = [xi.grad for xi in xs] 2488 self.assertEqual(result, torch.tensor([1.0, 2, 1, 2, 3]).unbind(0)) 2489 2490 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2491 def test_default_values(self): 2492 defaults = [] 2493 2494 @torch.library.custom_op("_torch_testing::f", mutates_args=()) 2495 def f( 2496 x: Tensor, 2497 a: Optional[int] = None, 2498 b: float = 3.14, 2499 c: bool = True, 2500 d: int = 3, 2501 e: str = "foo", 2502 f: torch.dtype = torch.float, 2503 g: torch.dtype = torch.float32, 2504 h: torch.dtype = torch.int, 2505 i: torch.device = 
torch.device("cpu:0"), 2506 j: torch.device = "cpu", 2507 ) -> Tensor: 2508 defaults.extend([a, b, c, d, e, f, g, h, i, j]) 2509 return x.clone() 2510 2511 x = torch.randn(3) 2512 f(x) 2513 self.assertEqual( 2514 defaults, 2515 [ 2516 None, 2517 3.14, 2518 True, 2519 3, 2520 "foo", 2521 torch.float, 2522 torch.float32, 2523 torch.int, 2524 torch.device("cpu:0"), 2525 "cpu", 2526 ], 2527 ) 2528 default_values = [ 2529 arg.default_value 2530 for arg in torch.ops._torch_testing.f.default._schema.arguments 2531 ] 2532 # enum values taken from c10/core/ScalarType.h 2533 type_enum = { 2534 "float": 6, 2535 "int": 3, 2536 } 2537 self.assertEqual( 2538 default_values, 2539 [ 2540 None, 2541 None, 2542 3.14, 2543 True, 2544 3, 2545 "foo", 2546 type_enum["float"], 2547 type_enum["float"], 2548 type_enum["int"], 2549 torch.device("cpu:0"), 2550 torch.device("cpu"), 2551 ], 2552 ) 2553 2554 def test_mutated_error(self): 2555 with self.assertRaisesRegex( 2556 ValueError, r".*{'y'} in mutates_args were not found" 2557 ): 2558 2559 @torch.library.custom_op( 2560 "_torch_testing::numpy_sin_inplace", 2561 mutates_args={"y"}, 2562 device_types="cpu", 2563 ) 2564 def numpy_sin_inplace(x: Tensor) -> None: 2565 x_np = x.numpy() 2566 np.sin(x_np, out=x_np) 2567 2568 def test_mutated(self): 2569 @torch.library.custom_op( 2570 "_torch_testing::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu" 2571 ) 2572 def numpy_sin_inplace(x: Tensor) -> None: 2573 x_np = x.numpy() 2574 np.sin(x_np, out=x_np) 2575 2576 x = torch.randn(3) 2577 version = x._version 2578 expected = x.sin() 2579 numpy_sin_inplace(x) 2580 self.assertEqual(x, expected) 2581 self.assertGreater(x._version, version) 2582 2583 @torch.library.custom_op("_torch_testing::f", mutates_args={"y", "z", "w"}) 2584 def f( 2585 x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]] 2586 ) -> None: 2587 return 2588 2589 x = torch.randn(3) 2590 y = torch.randn(3) 2591 z = [torch.randn(3), torch.randn(3)] 2592 w = [torch.randn(3), None, torch.randn(3)] 2593 initial_versions = pytree.tree_map_only( 2594 torch.Tensor, lambda x: x._version, (x, y, z, w) 2595 ) 2596 f(x, y, z, w) 2597 new_versions = pytree.tree_map_only( 2598 torch.Tensor, lambda x: x._version, (x, y, z, w) 2599 ) 2600 2601 self.assertEqual(initial_versions[0], new_versions[0]) 2602 initial_versions, _ = pytree.tree_flatten(initial_versions[1:]) 2603 new_versions, _ = pytree.tree_flatten(new_versions[1:]) 2604 for prev, after in zip(initial_versions, new_versions): 2605 if prev is None and after is None: 2606 continue 2607 self.assertGreater(after, prev) 2608 2609 def test_mutated_unknown(self): 2610 @torch.library.custom_op( 2611 "_torch_testing::f", mutates_args="unknown", device_types="cpu" 2612 ) 2613 def f(x: Tensor) -> None: 2614 x_np = x.numpy() 2615 np.sin(x_np, out=x_np) 2616 2617 x = torch.randn(3) 2618 version = x._version 2619 expected = x.sin() 2620 f(x) 2621 self.assertEqual(x, expected) 2622 self.assertGreater(x._version, version) 2623 2624 @torch.library.custom_op("_torch_testing::f2", mutates_args="unknown") 2625 def f2( 2626 x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]] 2627 ) -> None: 2628 return 2629 2630 x = torch.randn(3) 2631 y = torch.randn(3) 2632 z = [torch.randn(3), torch.randn(3)] 2633 w = [torch.randn(3), None, torch.randn(3)] 2634 initial_versions = pytree.tree_map_only( 2635 torch.Tensor, lambda x: x._version, (x, y, z, w) 2636 ) 2637 f2(x, y, z, w) 2638 new_versions = pytree.tree_map_only( 2639 torch.Tensor, lambda 
x: x._version, (x, y, z, w) 2640 ) 2641 2642 initial_versions, _ = pytree.tree_flatten(initial_versions) 2643 new_versions, _ = pytree.tree_flatten(new_versions) 2644 for prev, after in zip(initial_versions, new_versions): 2645 if prev is None and after is None: 2646 continue 2647 self.assertGreater(after, prev) 2648 2649 with self.assertRaisesRegex(ValueError, "string"): 2650 2651 @torch.library.custom_op("_torch_testing::f3", mutates_args="x") 2652 def f3(x: Tensor) -> None: 2653 return 2654 2655 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2656 def test_library_register_torch_dispatch_rule_subclass(self): 2657 from torch.testing._internal.two_tensor import TwoTensor 2658 2659 @torch.library.custom_op("mylib::foo", mutates_args={}) 2660 def f(x: torch.Tensor) -> torch.Tensor: 2661 return x.sin() 2662 2663 x = torch.randn(3) 2664 y = torch.randn(3) 2665 z = TwoTensor(x, y) 2666 2667 with torch.library._scoped_library("mylib", "FRAGMENT") as m: 2668 called = 0 2669 2670 def TwoTensor_foo(cls, func, types, args, kwargs): 2671 nonlocal called 2672 assert cls is TwoTensor 2673 called += 1 2674 return x.sin() 2675 2676 m._register_torch_dispatch_rule("foo", TwoTensor, TwoTensor_foo) 2677 2678 out = f(z) 2679 out2 = z.cos() 2680 2681 self.assertEqual(called, 1) 2682 2683 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2684 def test_library_register_torch_dispatch_rule_mode(self): 2685 from torch.testing._internal.two_tensor import TwoTensorMode 2686 2687 @torch.library.custom_op("mylib::foo", mutates_args={}) 2688 def f(x: torch.Tensor) -> torch.Tensor: 2689 return x.sin() 2690 2691 x = torch.randn(3) 2692 2693 with torch.library._scoped_library("mylib", "FRAGMENT") as m: 2694 called = 0 2695 2696 def TwoTensor_foo(mode, func, types, args, kwargs): 2697 nonlocal called 2698 called += 1 2699 return x.sin() 2700 2701 m._register_torch_dispatch_rule("foo", TwoTensorMode, TwoTensor_foo) 2702 2703 with TwoTensorMode(): 2704 out = f(x) 2705 out2 = x.cos() 2706 2707 self.assertEqual(called, 1) 2708 2709 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2710 @parametrize("idx", [0, 1, 2, 3, 4, 5]) 2711 def test_library_register_fake_source(self, idx): 2712 opname = f"source{idx}" 2713 op = getattr(torch.ops._torch_testing, opname).default 2714 entry = torch._library.simple_registry.singleton.find(op._name) 2715 source = entry.fake_impl.kernel.source 2716 assert source is not None 2717 self.assertTrue("custom_op_db.py" in source) 2718 2719 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2720 def test_library_register_fake(self): 2721 for mode in ["function", "qualname", "opoverload"]: 2722 2723 @torch.library.custom_op("_torch_testing::add", mutates_args=()) 2724 def add(x: Tensor, y: float) -> Tensor: 2725 x_np = x.cpu().numpy() 2726 out_np = x_np + y 2727 return torch.from_numpy(out_np).to(x.device) 2728 2729 called = False 2730 2731 if mode == "function": 2732 dec = torch.library.register_fake(add) 2733 self.assertIsNotNone(dec) 2734 elif mode == "qualname": 2735 dec = torch.library.register_fake("_torch_testing::add") 2736 self.assertIsNotNone(dec) 2737 elif mode == "opoverload": 2738 dec = torch.library.register_fake(torch.ops._torch_testing.add.default) 2739 self.assertIsNotNone(dec) 2740 else: 2741 raise AssertionError("should not get here") 2742 2743 @dec 2744 def _(x, y): 2745 nonlocal called 2746 called = True 2747 return torch.empty_like(x) 2748 2749 with 
torch._subclasses.fake_tensor.FakeTensorMode(): 2750 x = torch.randn(3) 2751 y = 3.14 2752 z = add(x, y) 2753 self.assertEqual(z.shape, x.shape) 2754 self.assertTrue(called) 2755 2756 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2757 def test_library_register_torch_dispatch(self): 2758 for mode in ["function", "qualname", "opoverload"]: 2759 2760 class MyMode(torch.utils._python_dispatch.TorchDispatchMode): 2761 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 2762 return func(*args, **kwargs) 2763 2764 @torch.library.custom_op("_torch_testing::add", mutates_args=()) 2765 def add(x: Tensor, y: float) -> Tensor: 2766 x_np = x.cpu().numpy() 2767 out_np = x_np + y 2768 return torch.from_numpy(out_np).to(x.device) 2769 2770 called = False 2771 2772 if mode == "function": 2773 dec = torch.library.register_torch_dispatch(add, MyMode) 2774 self.assertIsNotNone(dec) 2775 elif mode == "qualname": 2776 dec = torch.library.register_torch_dispatch( 2777 "_torch_testing::add", MyMode 2778 ) 2779 self.assertIsNotNone(dec) 2780 elif mode == "opoverload": 2781 dec = torch.library.register_torch_dispatch( 2782 torch.ops._torch_testing.add.default, MyMode 2783 ) 2784 self.assertIsNotNone(dec) 2785 else: 2786 raise AssertionError("should not get here") 2787 2788 @dec 2789 def _(mode, func, types, args, kwargs): 2790 nonlocal called 2791 called = True 2792 return func(*args, **kwargs) 2793 2794 with MyMode(): 2795 x = torch.randn(3) 2796 y = 3.14 2797 z = add(x, y) 2798 self.assertEqual(z.shape, x.shape) 2799 self.assertTrue(called) 2800 2801 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2802 def test_library_register_torch_dispatch_low_level(self): 2803 modes = ["qualname", "opoverload"] 2804 calls = ["decorator", "function"] 2805 device_types_options = [("cpu", "cuda"), "cpu", None] 2806 2807 for mode, call, device_types in itertools.product( 2808 modes, calls, device_types_options 2809 ): 2810 with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: 2811 lib.define("add10(Tensor x, float y) -> Tensor") 2812 2813 if mode == "qualname": 2814 op = "_torch_testing::add10" 2815 else: 2816 assert mode == "opoverload" 2817 op = torch.ops._torch_testing.add10.default 2818 2819 called = False 2820 2821 class MyMode(torch.utils._python_dispatch.TorchDispatchMode): 2822 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 2823 return func(*args, **kwargs) 2824 2825 if call == "decorator": 2826 2827 @torch.library.register_torch_dispatch(op, MyMode, lib=lib) 2828 def _(mode, func, types, args, kwargs): 2829 x, y = args 2830 nonlocal called 2831 called = True 2832 return x + y 2833 2834 else: 2835 assert call == "function" 2836 2837 def add_stuff(mode, func, types, args, kwargs): 2838 x, y = args 2839 nonlocal called 2840 called = True 2841 return x + y 2842 2843 torch.library.register_torch_dispatch( 2844 op, MyMode, add_stuff, lib=lib 2845 ) 2846 2847 x = torch.randn(3) 2848 y = 3.14 2849 with MyMode(): 2850 z = torch.ops._torch_testing.add10.default(x, y) 2851 self.assertEqual(z, x + y) 2852 self.assertTrue(called) 2853 2854 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2855 def test_library_register_kernel(self): 2856 modes = ["function", "qualname", "opoverload"] 2857 calls = ["decorator", "function"] 2858 device_types_options = ["cpu", None] 2859 2860 for mode, call, device_types in itertools.product( 2861 modes, calls, device_types_options 2862 ): 2863 2864 
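            # Each iteration defines a fresh op with only a CUDA kernel; a
            # CPU-capable kernel is then supplied via torch.library.register_kernel.
            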
@torch.library.custom_op( 2865 "_torch_testing::add", mutates_args=(), device_types="cuda" 2866 ) 2867 def add(x: Tensor, y: float) -> Tensor: 2868 x_np = x.cpu().numpy() 2869 out_np = x_np + y 2870 return torch.from_numpy(out_np).to(x.device) 2871 2872 if mode == "function": 2873 op = add 2874 elif mode == "qualname": 2875 op = "_torch_testing::add" 2876 else: 2877 assert mode == "opoverload" 2878 op = torch.ops._torch_testing.add.default 2879 2880 called = False 2881 2882 if call == "decorator": 2883 2884 @torch.library.register_kernel(op, device_types) 2885 def _(x, y): 2886 nonlocal called 2887 called = True 2888 x_np = x.numpy() 2889 out_np = x_np + y 2890 return torch.from_numpy(out_np) 2891 2892 else: 2893 assert call == "function" 2894 2895 def add_cpu(x, y): 2896 nonlocal called 2897 called = True 2898 x_np = x.numpy() 2899 out_np = x_np + y 2900 return torch.from_numpy(out_np) 2901 2902 torch.library.register_kernel(op, device_types, add_cpu) 2903 2904 x = torch.randn(3) 2905 y = 3.14 2906 z = add(x, y) 2907 self.assertEqual(z, x + y) 2908 self.assertTrue(called) 2909 2910 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2911 def test_library_register_kernel_low_level(self): 2912 modes = ["qualname", "opoverload"] 2913 calls = ["decorator", "function"] 2914 device_types_options = [("cpu", "cuda"), "cpu", None] 2915 2916 for mode, call, device_types in itertools.product( 2917 modes, calls, device_types_options 2918 ): 2919 with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: 2920 lib.define("add9(Tensor x, float y) -> Tensor") 2921 2922 if mode == "qualname": 2923 op = "_torch_testing::add9" 2924 else: 2925 assert mode == "opoverload" 2926 op = torch.ops._torch_testing.add9.default 2927 2928 called = False 2929 2930 if call == "decorator": 2931 2932 @torch.library.register_kernel(op, device_types, lib=lib) 2933 def _(x, y): 2934 nonlocal called 2935 called = True 2936 x_np = x.numpy() 2937 out_np = x_np + y 2938 return torch.from_numpy(out_np) 2939 2940 else: 2941 assert call == "function" 2942 2943 def add_cpu(x, y): 2944 nonlocal called 2945 called = True 2946 x_np = x.numpy() 2947 out_np = x_np + y 2948 return torch.from_numpy(out_np) 2949 2950 torch.library.register_kernel(op, device_types, add_cpu, lib=lib) 2951 2952 x = torch.randn(3) 2953 y = 3.14 2954 z = torch.ops._torch_testing.add9.default(x, y) 2955 self.assertEqual(z, x + y) 2956 self.assertTrue(called) 2957 2958 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 2959 def test_library_register_autograd(self): 2960 for mode in ["function", "qualname", "opoverload"]: 2961 2962 @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) 2963 def numpy_sin(x: Tensor) -> Tensor: 2964 x_np = x.cpu().numpy() 2965 y_np = np.sin(x_np) 2966 return torch.from_numpy(y_np).to(device=x.device) 2967 2968 def setup_context(ctx, inputs, output) -> Tensor: 2969 (x,) = inputs 2970 ctx.save_for_backward(x) 2971 2972 called = False 2973 2974 def backward(ctx, grad): 2975 nonlocal called 2976 called = True 2977 (x,) = ctx.saved_tensors 2978 return grad * x.cos() 2979 2980 if mode == "function": 2981 torch.library.register_autograd( 2982 numpy_sin, backward, setup_context=setup_context 2983 ) 2984 elif mode == "qualname": 2985 torch.library.register_autograd( 2986 "mylib::numpy_sin", backward, setup_context=setup_context 2987 ) 2988 elif mode == "opoverload": 2989 torch.library.register_autograd( 2990 torch.ops.mylib.numpy_sin.default, 2991 backward, 2992 
setup_context=setup_context, 2993 ) 2994 2995 x = torch.randn(3, requires_grad=True) 2996 y = numpy_sin(x) 2997 (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) 2998 self.assertTrue(called) 2999 self.assertEqual(grad_x, x.cos()) 3000 3001 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3002 def test_library_register_autograd_low_level(self): 3003 for mode in ["qualname", "opoverload"]: 3004 with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: 3005 lib.define("sin5(Tensor x) -> Tensor") 3006 3007 def numpy_sin(x: Tensor) -> Tensor: 3008 x_np = x.cpu().detach().numpy() 3009 y_np = np.sin(x_np) 3010 return torch.from_numpy(y_np).to(device=x.device) 3011 3012 def setup_context(ctx, inputs, output) -> Tensor: 3013 (x,) = inputs 3014 ctx.save_for_backward(x) 3015 3016 called = False 3017 3018 def backward(ctx, grad): 3019 nonlocal called 3020 called = True 3021 (x,) = ctx.saved_tensors 3022 return grad * x.cos() 3023 3024 lib.impl("sin5", numpy_sin, "CPU") 3025 3026 called = False 3027 3028 if mode == "qualname": 3029 torch.library.register_autograd( 3030 "_torch_testing::sin5", 3031 backward, 3032 setup_context=setup_context, 3033 lib=lib, 3034 ) 3035 elif mode == "opoverload": 3036 torch.library.register_autograd( 3037 torch.ops._torch_testing.sin5.default, 3038 backward, 3039 setup_context=setup_context, 3040 lib=lib, 3041 ) 3042 x = torch.randn(3, requires_grad=True) 3043 y = torch.ops._torch_testing.sin5(x) 3044 (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) 3045 self.assertTrue(called) 3046 self.assertEqual(grad_x, x.cos()) 3047 3048 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3049 def test_fake(self): 3050 @torch.library.custom_op("_torch_testing::add", mutates_args=()) 3051 def add(x: Tensor, y: float) -> Tensor: 3052 x_np = x.cpu().numpy() 3053 out_np = x_np + y 3054 return torch.from_numpy(out_np).to(x.device) 3055 3056 x = torch.randn(3) 3057 y = 3.14 3058 z = add(x, y) 3059 self.assertEqual(z, x + y) 3060 3061 try: 3062 with torch._subclasses.fake_tensor.FakeTensorMode(): 3063 x = torch.randn(3) 3064 add(x, y) 3065 raise AssertionError("should not be hit") 3066 except RuntimeError as e: 3067 abstract_impl_error_msg = str(e) 3068 abstract_impl_error_msg = re.sub( 3069 r"0x.*>\)>", "0xDEADBEEF>)>", abstract_impl_error_msg 3070 ).replace(". ", ".\n") 3071 self.assertExpectedInline( 3072 abstract_impl_error_msg, 3073 """\ 3074There was no fake impl registered for <CustomOpDef(_torch_testing::add)>. 3075This is necessary for torch.compile/export/fx tracing to work. 
Please use `add.register_fake` to add an fake impl.""",
            )

        if not IS_WINDOWS:

            @torch.compile(backend="eager")
            def f(x, y):
                return add(x, y)

            x = torch.randn(3)
            with self.assertRaisesRegex(RuntimeError, "no fake impl"):
                f(x, y)

        abstract_called = False

        @add.register_fake
        def _(x, y):
            nonlocal abstract_called
            abstract_called = True
            return torch.empty_like(x)

        with torch._subclasses.fake_tensor.FakeTensorMode():
            x = torch.randn(3)
            z = add(x, y)
            self.assertEqual(z.shape, x.shape)
            self.assertTrue(abstract_called)

    @skipIfTorchDynamo("recursive dynamo")
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    def test_compile(self):
        called_impl = False
        called_abstract = False

        @torch.library.custom_op("_torch_testing::linear", mutates_args=())
        def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            nonlocal called_impl
            called_impl = True
            x_np = x.numpy()
            w_np = weight.numpy()
            b_np = bias.numpy()
            # Compute in numpy, then wrap: a custom op must return a Tensor,
            # not an ndarray.
            out_np = np.add(x_np @ w_np.T, b_np)
            return torch.from_numpy(out_np)

        @custom_linear.register_fake
        def _(x, weight, bias):
            nonlocal called_abstract
            called_abstract = True
            assert x.dim() == 2
            assert weight.dim() == 2
            assert bias.dim() == 1
            assert x.shape[1] == weight.shape[1]
            assert weight.shape[0] == bias.shape[0]
            assert x.device == weight.device
            return x.new_empty(x.size(0), weight.size(0))

        x = torch.randn(2, 2)
        weight = torch.randn(2, 2)
        bias = torch.randn(2)
        out = torch.compile(custom_linear, backend="eager", fullgraph=True)(
            x, weight, bias
        )
        self.assertEqual(out, torch.nn.functional.linear(x, weight, bias))
        self.assertTrue(called_impl)
        self.assertTrue(called_abstract)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_autograd_error_cases(self):
        @torch.library.custom_op("_torch_testing::g", mutates_args=())
        def g(x: Tensor) -> Tensor:
            return x.sin()

        x = torch.randn(3, requires_grad=True)
        y = g(x)
        with self.assertRaisesRegex(RuntimeError, "no autograd formula"):
            y.sum().backward()

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_replacement(self):
        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.sin()

        x = torch.randn(3)
        y = f(x)
        self.assertEqual(y, x.sin())

        @torch.library.custom_op("_torch_testing::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x.cos()

        y = f(x)
        self.assertEqual(y, x.cos())

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_split_device(self):
        cpu_call_count = 0
        cuda_call_count = 0

        @torch.library.custom_op(
            "_torch_testing::f", mutates_args=(), device_types="cpu"
        )
        def f(x: Tensor) -> Tensor:
            nonlocal cpu_call_count
            cpu_call_count += 1
            x_np = x.numpy()
            out_np = np.sin(x_np)
            return torch.from_numpy(out_np)

        @f.register_kernel("cuda")
        def _(x: Tensor) -> Tensor:
            nonlocal cuda_call_count
            cuda_call_count += 1
            x_np = x.cpu().numpy()
            out_np = np.sin(x_np)
            return \
torch.from_numpy(out_np).to(x.device) 3192 3193 x = torch.randn(3) 3194 y = f(x) 3195 self.assertEqual(y, x.sin()) 3196 self.assertEqual(cpu_call_count, 1) 3197 self.assertEqual(cuda_call_count, 0) 3198 3199 x = x.cuda() 3200 y = f(x) 3201 self.assertEqual(y, x.sin()) 3202 self.assertEqual(cpu_call_count, 1) 3203 self.assertEqual(cuda_call_count, 1) 3204 3205 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3206 @unittest.skipIf(not TEST_CUDA, "requires CUDA") 3207 def test_multi_types(self): 3208 @torch.library.custom_op( 3209 "_torch_testing::f", mutates_args=(), device_types=("cpu", "cuda") 3210 ) 3211 def f(x: Tensor) -> Tensor: 3212 x_np = x.cpu().numpy() 3213 out_np = np.sin(x_np) 3214 return torch.from_numpy(out_np).to(x.device) 3215 3216 x = torch.randn(3) 3217 y = f(x) 3218 self.assertEqual(y, x.sin()) 3219 x = x.cuda() 3220 y = f(x) 3221 self.assertEqual(y, x.sin()) 3222 3223 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3224 def test_overloading(self): 3225 called_f = 0 3226 called_f1 = 0 3227 3228 @torch.library.custom_op("_torch_testing::f", mutates_args=()) 3229 def f(x: Tensor) -> Tensor: 3230 nonlocal called_f 3231 called_f += 1 3232 return x.clone() 3233 3234 x = torch.randn(2, 3) 3235 torch.ops._torch_testing.f(x) 3236 self.assertEqual(called_f, 1) 3237 3238 @torch.library.custom_op("_torch_testing::f.overload", mutates_args=()) 3239 def f1(x: Tensor, y: Tensor) -> Tensor: 3240 nonlocal called_f1 3241 called_f1 += 1 3242 return x.clone() 3243 3244 torch.ops._torch_testing.f(x, x) 3245 self.assertEqual(called_f1, 1) 3246 3247 def test_disallows_output_aliasing(self): 3248 @torch.library.custom_op("_torch_testing::f", mutates_args=()) 3249 def f(x: Tensor) -> Tensor: 3250 return x.view(-1) 3251 3252 x = torch.randn(3) 3253 with self.assertRaisesRegex(RuntimeError, "may not alias"): 3254 f(x) 3255 3256 @torch.library.custom_op("_torch_testing::f", mutates_args=()) 3257 def f(x: Tensor) -> Tensor: 3258 return x 3259 3260 x = torch.randn(3) 3261 with self.assertRaisesRegex(RuntimeError, "may not alias"): 3262 f(x) 3263 3264 @torch.library.custom_op( 3265 "_torch_testing::f", mutates_args={"x"}, device_types="cpu" 3266 ) 3267 def numpy_sin_inplace(x: Tensor) -> Tensor: 3268 x_np = x.numpy() 3269 np.sin(x_np, out=x_np) 3270 return x 3271 3272 x = torch.randn(3) 3273 with self.assertRaisesRegex(RuntimeError, "may not alias"): 3274 numpy_sin_inplace(x) 3275 3276 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3277 def test_factory_function(self): 3278 @torch.library.custom_op( 3279 "_torch_testing::f", mutates_args={}, device_types="cpu" 3280 ) 3281 def f(device: torch.device) -> Tensor: 3282 return torch.ones(3) 3283 3284 result = f(device="cpu") 3285 self.assertEqual(result.device, torch.device("cpu")) 3286 self.assertEqual(result, torch.ones(3)) 3287 3288 with self.assertRaisesRegex( 3289 RuntimeError, "f does not have a kernel registered for cuda" 3290 ): 3291 f("cuda") 3292 3293 with self.assertRaisesRegex( 3294 ValueError, 3295 "Functions without tensor inputs are required to have a `device: torch.device` argument", 3296 ): 3297 3298 @torch.library.custom_op( 3299 "_torch_testing::f2", mutates_args={}, device_types="cpu" 3300 ) 3301 def f2() -> Tensor: 3302 return torch.ones(3) 3303 3304 @torch.library.custom_op("_torch_testing::f3", mutates_args={}) 3305 def f3() -> Tensor: 3306 raise NotImplementedError("NYI") 3307 3308 with self.assertRaisesRegex( 3309 ValueError, 3310 "Functions 
without tensor inputs are required to have a `device: torch.device` argument",
        ):

            @f3.register_kernel("cpu")
            def _():
                return torch.zeros(3)

        @torch.library.custom_op("_torch_testing::f4", mutates_args={})
        def f4(device: torch.device) -> Tensor:
            raise NotImplementedError("NYI")

        @f4.register_kernel("cpu")
        def _(device: torch.device):
            return torch.zeros(3)

        # Exercise the kernel that was just registered for f4.
        result = f4(device="cpu")
        self.assertEqual(result.device, torch.device("cpu"))
        self.assertEqual(result, torch.zeros(3))

    def test_library_schema_infer(self):
        def foo_impl(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
        self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")

        schema = torch.library.infer_schema(foo_impl, mutates_args={})
        self.assertExpectedInline(schema, "(Tensor x) -> Tensor")

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_set_kernel_enabled(self):
        x = torch.ones(1)

        @torch.library.custom_op("mylib::f", mutates_args=())
        def f(x: Tensor) -> Tensor:
            return x + 1

        self.assertEqual(f(x), x + 1)
        with self.assertLogs("torch._library.custom_ops") as captured:
            with f.set_kernel_enabled("gpu", enabled=False):
                self.assertEqual(f(x), x + 1)
            self.assertIn(
                "no kernel was registered for this device type", captured.output[0]
            )

        @f.register_kernel("cpu")
        def _(x):
            return x + 2

        self.assertEqual(f(x), x + 2)

        with self.assertLogs("torch._library.custom_ops") as captured:
            with f.set_kernel_enabled("cpu", enabled=True):
                self.assertEqual(f(x), x + 2)
            self.assertIn("already enabled", captured.output[0])

        with f.set_kernel_enabled("cpu", enabled=False):
            self.assertEqual(f(x), x + 1)

            with self.assertLogs("torch._library.custom_ops") as captured:
                with f.set_kernel_enabled("cpu", enabled=False):
                    self.assertEqual(f(x), x + 1)
                self.assertIn("already disabled", captured.output[0])

            self.assertEqual(f(x), x + 1)

            with f.set_kernel_enabled("cpu", enabled=True):
                self.assertEqual(f(x), x + 2)

        with f.set_kernel_enabled("cpu", enabled=False):
            self.assertEqual(f(x), x + 1)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_vmap_kwargonly_low_level(self):
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor x, *, float y) -> Tensor")
            called = False

            def foo_impl(x, *, y):
                return x * y

            lib.impl("foo", foo_impl, "CPU")

            def vmap(info, in_dims, x, *, y):
                nonlocal called
                called = True
                return x * y, 0

            torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib)

            x = torch.ones(3)
            result = torch.vmap(torch.ops._torch_testing.foo)(x, y=3.14)
            self.assertTrue(called)
            self.assertEqual(result, torch.tensor([3.14, 3.14, 3.14]))

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_register_vmap_defaults(self):
        with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
            lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor")

            def foo_impl(w, x=2, *, y=3, z):
                return w * x * y * z

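            # foo's schema carries defaults (x=2, y=3); both the CPU kernel and
            # the vmap rule below accept the same signature.
            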
lib.impl("foo", foo_impl, "CPU") 3416 3417 called = False 3418 3419 def vmap(info, in_dims, w, x=2, *, y=3, z): 3420 nonlocal called 3421 called = True 3422 return w * x * y * z, 0 3423 3424 torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib) 3425 3426 w = torch.ones(3) 3427 result = torch.vmap(torch.ops._torch_testing.foo)(w, z=42) 3428 self.assertTrue(called) 3429 self.assertEqual(result, w * 2 * 3 * 42) 3430 3431 def test_layout_constraint_tags(self): 3432 needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order 3433 flexible_layout = torch._C.Tag.flexible_layout 3434 # (tags, the result of the tag inference) 3435 tests = [ 3436 ({needs_fixed_stride_order}, needs_fixed_stride_order), 3437 ({flexible_layout}, flexible_layout), 3438 # If no tags are provided, then the following is the default 3439 (set(), flexible_layout), 3440 # If multiple tags are provided, then we use the most constrained tag. 3441 ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order), 3442 ] 3443 from torch._inductor.lowering import get_layout_constraint_tag 3444 3445 for tags, expected in tests: 3446 with torch.library._scoped_library("mylib", "FRAGMENT") as m: 3447 m.define("foobar(Tensor x) -> Tensor", tags=tags) 3448 result = get_layout_constraint_tag(torch.ops.mylib.foobar.default) 3449 self.assertEqual(result, expected) 3450 3451 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3452 def test_library_register_vmap(self): 3453 for mode in ["function", "qualname", "opoverload", "c_opdef"]: 3454 3455 @torch.library.custom_op("mylib::f", mutates_args=()) 3456 def f(x: Tensor, y: Tensor) -> Tensor: 3457 return x * y 3458 3459 called = False 3460 3461 def fvmap(info, in_dims, x, y): 3462 nonlocal called 3463 called = True 3464 x_bdim, y_bdim = in_dims 3465 x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3466 y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3467 result = x * y 3468 result = result.movedim(-1, 0) 3469 return result, 0 3470 3471 if mode == "function": 3472 torch.library.register_vmap(f, fvmap) 3473 elif mode == "qualname": 3474 torch.library.register_vmap("mylib::f", fvmap) 3475 elif mode == "opoverload": 3476 torch.library.register_vmap(torch.ops.mylib.f.default, fvmap) 3477 elif mode == "c_opdef": 3478 f.register_vmap(fvmap) 3479 3480 x = torch.randn(2, 2) 3481 y = torch.randn(2, 2) 3482 3483 result = torch.vmap(f)(x, y) 3484 self.assertTrue(called) 3485 self.assertEqual(result, x * y) 3486 3487 called = False 3488 result = torch.vmap(f, out_dims=1)(x, y) 3489 self.assertEqual(result, (x * y).T) 3490 self.assertTrue(called) 3491 3492 called = False 3493 result = torch.vmap(f, in_dims=1)(x, y) 3494 self.assertEqual(result, (x * y).T) 3495 self.assertTrue(called) 3496 3497 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3498 def test_library_register_vmap_library_decorator(self): 3499 @torch.library.custom_op("mylib::f", mutates_args=()) 3500 def f(x: Tensor, y: Tensor) -> Tensor: 3501 return x * y 3502 3503 called = False 3504 3505 @torch.library.register_vmap("mylib::f") 3506 def fvmap(info, in_dims, x, y): 3507 nonlocal called 3508 called = True 3509 x_bdim, y_bdim = in_dims 3510 x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3511 y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3512 result = x * y 3513 result = result.movedim(-1, 0) 3514 return result, 0 3515 3516 x = torch.randn(2, 2) 3517 y = torch.randn(2, 2) 3518 
3519 result = torch.vmap(f)(x, y) 3520 self.assertTrue(called) 3521 self.assertEqual(result, x * y) 3522 3523 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3524 def test_library_register_vmap_op_decorator(self): 3525 @torch.library.custom_op("mylib::f", mutates_args=()) 3526 def f(x: Tensor, y: Tensor) -> Tensor: 3527 return x * y 3528 3529 called = False 3530 3531 @f.register_vmap 3532 def fvmap(info, in_dims, x, y): 3533 nonlocal called 3534 called = True 3535 x_bdim, y_bdim = in_dims 3536 x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3537 y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3538 result = x * y 3539 result = result.movedim(-1, 0) 3540 return result, 0 3541 3542 x = torch.randn(2, 2) 3543 y = torch.randn(2, 2) 3544 3545 result = torch.vmap(f)(x, y) 3546 self.assertTrue(called) 3547 self.assertEqual(result, x * y) 3548 3549 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3550 def test_library_register_vmap_register_multiple_times(self): 3551 @torch.library.custom_op("mylib::f", mutates_args=()) 3552 def f(x: Tensor, y: Tensor) -> Tensor: 3553 return x * y 3554 3555 called = False 3556 3557 @f.register_vmap 3558 def fvmap(info, in_dims, x, y): 3559 nonlocal called 3560 called = True 3561 x_bdim, y_bdim = in_dims 3562 x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3563 y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3564 result = x * y 3565 result = result.movedim(-1, 0) 3566 return result, 0 3567 3568 x = torch.randn(2, 2) 3569 y = torch.randn(2, 2) 3570 3571 result = torch.vmap(f)(x, y) 3572 self.assertTrue(called) 3573 self.assertEqual(result, x * y) 3574 called = False 3575 3576 @f.register_vmap 3577 def fvmap2(info, in_dims, x, y): 3578 nonlocal called 3579 called = True 3580 x_bdim, y_bdim = in_dims 3581 x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3582 y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3583 result = x + y 3584 result = result.movedim(-1, 0) 3585 return result, 0 3586 3587 result = torch.vmap(f)(x, y) 3588 self.assertTrue(called) 3589 self.assertEqual(result, x + y) 3590 3591 @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") 3592 def test_library_register_vmap_register_multiple_times_2(self): 3593 @torch.library.custom_op("mylib::f", mutates_args=()) 3594 def f(x: Tensor, y: Tensor) -> Tensor: 3595 return x * y 3596 3597 called = False 3598 3599 @torch.library.register_vmap("mylib::f") 3600 def fvmap(info, in_dims, x, y): 3601 nonlocal called 3602 called = True 3603 x_bdim, y_bdim = in_dims 3604 x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3605 y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3606 result = x * y 3607 result = result.movedim(-1, 0) 3608 return result, 0 3609 3610 x = torch.randn(2, 2) 3611 y = torch.randn(2, 2) 3612 3613 result = torch.vmap(f)(x, y) 3614 self.assertTrue(called) 3615 self.assertEqual(result, x * y) 3616 called = False 3617 3618 @torch.library.register_vmap("mylib::f") 3619 def fvmap2(info, in_dims, x, y): 3620 nonlocal called 3621 called = True 3622 x_bdim, y_bdim = in_dims 3623 x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 3624 y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 3625 result = x + y 3626 result = result.movedim(-1, 0) 3627 return result, 0 3628 3629 result = torch.vmap(f)(x, y) 3630 self.assertTrue(called) 
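        # fvmap2 replaced fvmap, so the vmapped op now computes x + y.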
3631 self.assertEqual(result, x + y) 3632 3633 3634class MiniOpTestOther(CustomOpTestCaseBase): 3635 test_ns = "mini_op_test" 3636 3637 def test_nonzero_again(self): 3638 x = torch.tensor([0, 1, 2, 0, 0]) 3639 y = torch.ops.aten.nonzero.default(x) 3640 self.assertEqual(y, torch.tensor([[1], [2]])) 3641 3642 3643optests.generate_opcheck_tests( 3644 MiniOpTest, 3645 ["aten", "mini_op_test"], 3646 get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"), 3647 additional_decorators={ 3648 "test_pt2_compliant_tag_mini_op_test_no_abstract": [unittest.expectedFailure] 3649 }, 3650 test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS, 3651) 3652 3653optests.generate_opcheck_tests( 3654 MiniOpTestOther, 3655 ["aten", "mini_op_test"], 3656 get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"), 3657 test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS, 3658) 3659 3660 3661class TestGenerateOpcheckTests(CustomOpTestCaseBase): 3662 def test_MiniOpTest(self): 3663 for orig_test in ["test_mm", "test_nonzero"]: 3664 for ( 3665 test 3666 ) in torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS: 3667 expected_test = f"{test}__{orig_test}" 3668 self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test) 3669 3670 def test_generate_repro_save_data(self): 3671 from torch.testing._internal.optests.generate_tests import generate_repro 3672 3673 args = (torch.ones(2, 2),) 3674 kwargs = {"mat2": torch.zeros(2, 2)} 3675 actual = generate_repro( 3676 "test_schema", 3677 torch.ops.aten.sin.default, 3678 args, 3679 kwargs, 3680 save_data=True, 3681 dry_run=True, 3682 ) 3683 actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual) 3684 self.assertExpectedInline( 3685 actual, 3686 """\ 3687# ========================================================= 3688# BEGIN REPRO SCRIPT 3689# ========================================================= 3690import torch 3691from torch.testing._internal.optests import opcheck 3692 3693# Make sure you have loaded the library that contains the op 3694# via an import or torch.ops.load_library(...) 3695op = torch.ops.aten.sin.default 3696 3697args, kwargs = torch.load("repro.pt") 3698opcheck(op, args, kwargs, test_utils="test_schema") 3699# ========================================================= 3700# END REPRO SCRIPT 3701# ========================================================= 3702""", 3703 ) 3704 3705 def test_generate_repro_no_save_data(self): 3706 from torch.testing._internal.optests.generate_tests import generate_repro 3707 3708 args = (torch.ones(2, 2),) 3709 kwargs = {"mat2": torch.zeros(2, 2)} 3710 actual = generate_repro( 3711 "test_schema", 3712 torch.ops.aten.sin.default, 3713 args, 3714 kwargs, 3715 save_data=False, 3716 dry_run=True, 3717 ) 3718 self.assertExpectedInline( 3719 actual, 3720 """\ 3721# ========================================================= 3722# BEGIN REPRO SCRIPT 3723# ========================================================= 3724import torch 3725from torch.testing._internal.optests import opcheck 3726 3727# Make sure you have loaded the library that contains the op 3728# via an import or torch.ops.load_library(...) 


class TestGenerateOpcheckTests(CustomOpTestCaseBase):
    def test_MiniOpTest(self):
        for orig_test in ["test_mm", "test_nonzero"]:
            for (
                test
            ) in torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS:
                expected_test = f"{test}__{orig_test}"
                self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test)

    def test_generate_repro_save_data(self):
        from torch.testing._internal.optests.generate_tests import generate_repro

        args = (torch.ones(2, 2),)
        kwargs = {"mat2": torch.zeros(2, 2)}
        actual = generate_repro(
            "test_schema",
            torch.ops.aten.sin.default,
            args,
            kwargs,
            save_data=True,
            dry_run=True,
        )
        actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual)
        self.assertExpectedInline(
            actual,
            """\
# =========================================================
# BEGIN REPRO SCRIPT
# =========================================================
import torch
from torch.testing._internal.optests import opcheck

# Make sure you have loaded the library that contains the op
# via an import or torch.ops.load_library(...)
op = torch.ops.aten.sin.default

args, kwargs = torch.load("repro.pt")
opcheck(op, args, kwargs, test_utils="test_schema")
# =========================================================
# END REPRO SCRIPT
# =========================================================
""",
        )

    def test_generate_repro_no_save_data(self):
        from torch.testing._internal.optests.generate_tests import generate_repro

        args = (torch.ones(2, 2),)
        kwargs = {"mat2": torch.zeros(2, 2)}
        actual = generate_repro(
            "test_schema",
            torch.ops.aten.sin.default,
            args,
            kwargs,
            save_data=False,
            dry_run=True,
        )
        self.assertExpectedInline(
            actual,
            """\
# =========================================================
# BEGIN REPRO SCRIPT
# =========================================================
import torch
from torch.testing._internal.optests import opcheck

# Make sure you have loaded the library that contains the op
# via an import or torch.ops.load_library(...)
op = torch.ops.aten.sin.default

# If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1
# we will fill them in same (args, kwargs) as in your test
args = ()  # args to the operator
kwargs = {}  # kwargs to the operator
opcheck(op, args, kwargs, test_utils="test_schema")
# =========================================================
# END REPRO SCRIPT
# =========================================================
""",
        )

    def test_failures_dict_validation(self):
        from torch.testing._internal.optests.generate_tests import (
            FailuresDict,
            validate_failures_dict_structure,
        )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error": {
                    "comment": "",
                    "status": "success",
                }
            }
        }
        with self.assertRaisesRegex(RuntimeError, "got status=success"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch__test_delayed_error": {
                    "comment": "",
                    "status": "xfail",
                },
            }
        }
        with self.assertRaisesRegex(RuntimeError, "should begin with one of"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )

        failures = {
            "mini_op_test::incorrect_schema": {
                "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error_nopenopenope": {
                    "comment": "",
                    "status": "xfail",
                },
            }
        }
        with self.assertRaisesRegex(RuntimeError, "does not exist on the TestCase"):
            validate_failures_dict_structure(
                FailuresDict("", failures),
                torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS,
                MiniOpTest,
            )

    def test_dont_generate_decorator(self):
        self.assertTrue(hasattr(MiniOpTest, "test_dont_generate"))
        self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate"))
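
    # torch.library.opcheck runs four test utils by default, as asserted
    # below: "test_schema", "test_autograd_registration", "test_faketensor",
    # and "test_aot_dispatch_dynamic" (schema correctness, autograd key
    # registration, the FakeTensor rule, and behavior under AOTAutograd with
    # dynamic shapes, respectively; these glosses are summaries, not verbatim
    # documentation).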

    def test_opcheck(self):
        x = torch.randn(3, requires_grad=True)
        with self.assertRaisesRegex(ValueError, "OpOverload"):
            torch.library.opcheck(torch.sin, (x,))
        with self.assertRaisesRegex(ValueError, "test_utils to be subset of"):
            torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah")
        result = torch.library.opcheck(torch.ops.aten.sin.default, (x,))

        self.assertEqual(
            result,
            {
                "test_schema": "SUCCESS",
                "test_autograd_registration": "SUCCESS",
                "test_faketensor": "SUCCESS",
                "test_aot_dispatch_dynamic": "SUCCESS",
            },
        )

        result = torch.library.opcheck(
            torch.ops.aten.sin.default, (x,), test_utils="test_schema"
        )
        self.assertEqual(result, {"test_schema": "SUCCESS"})

        result = torch.library.opcheck(
            torch.ops.aten.sin.default,
            (x,),
            test_utils=["test_schema", "test_faketensor"],
        )
        self.assertEqual(
            result,
            {
                "test_schema": "SUCCESS",
                "test_faketensor": "SUCCESS",
            },
        )

    def test_opcheck_customopdef(self):
        sample_inputs = [
            (torch.randn(3),),
            (torch.randn(3, requires_grad=True),),
        ]
        if torch.cuda.is_available():
            sample_inputs.extend(
                [
                    (torch.randn(3, device="cuda"),),
                    (torch.randn(3, device="cuda", requires_grad=True),),
                ]
            )
        for args in sample_inputs:
            torch.library.opcheck(custom_op_db.numpy_cube, args)

    def test_is_inside_opcheck_mode(self):
        self.assertFalse(optests.is_inside_opcheck_mode())
        with optests.generate_tests.OpCheckMode(
            ["foo"], "bar", lambda x: x, None, "baz", "brr"
        ):
            self.assertTrue(optests.is_inside_opcheck_mode())

    def test_opcheck_bad_op(self):
        op = op_with_incorrect_schema(self, "foo")
        x = torch.randn(3)
        with self.assertRaisesRegex(Exception, "is not defined to alias output"):
            torch.library.opcheck(op, (x,))

        result = torch.library.opcheck(op, (x,), raise_exception=False)
        self.assertTrue(isinstance(result["test_schema"], RuntimeError))
        del result["test_schema"]
        self.assertEqual(
            result,
            {
                "test_autograd_registration": "SUCCESS",
                "test_faketensor": "SUCCESS",
                "test_aot_dispatch_dynamic": "SUCCESS",
            },
        )

    def test_opcheck_does_not_require_extra_deps(self):
        # torch.testing._internal.common_utils comes with a lot of additional
        # test-time dependencies. Since opcheck is public API, it should be
        # usable with only PyTorch's install-time dependencies.
        cmd = [
            sys.executable,
            "-c",
            "import torch; import sys; \
               x = torch.randn(3, requires_grad=True); \
               torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \
               assert 'expecttest' not in sys.modules; \
               assert 'torch.testing._internal.common_utils' not in sys.modules",
        ]
        subprocess.check_output(cmd, shell=False)


class TestTypeConversion(TestCase):
    """In infer_schema(), we try to suggest a correct type when the type annotation is wrong."""

    def setUp(self):
        self.supported_base_types = [
            int,
            float,
            bool,
            str,
            torch.device,
            torch.Tensor,
            torch.dtype,
            torch.types.Number,
        ]

    def test_simple_tuple(self):
        self.assertEqual(List, tuple_to_list(Tuple))

    def test_supported_types(self):
        for t in self.supported_base_types:
            result_type = tuple_to_list(Tuple[t, t, t])
            self.assertEqual(result_type, List[t])

            result_type = tuple_to_list(Tuple[t])
            self.assertEqual(result_type, List[t])

    def test_optional(self):
        for t in self.supported_base_types:
            result_type = tuple_to_list(Tuple[t, Optional[t]])
            self.assertEqual(result_type, List[Optional[t]])

            result_type = tuple_to_list(Tuple[t, t, Optional[t]])
            self.assertEqual(result_type, List[Optional[t]])

            result_type = tuple_to_list(Tuple[t, ...])
            self.assertEqual(result_type, List[t])

    def test_mixed_types(self):
        result_type = tuple_to_list(Tuple[int, float])
        self.assertEqual(result_type, List[typing.Union[int, float]])

        result_type = tuple_to_list(Tuple[int, float, str])
        self.assertEqual(result_type, List[typing.Union[int, float, str]])


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for)
instantiate_parametrized_tests(TestCustomOp)
instantiate_parametrized_tests(TestCustomOpAPI)

if __name__ == "__main__":
    run_tests()