1# Owner(s): ["oncall: jit"] 2 3import io 4import os 5import sys 6import unittest 7from typing import Any 8 9import torch 10import torch.nn as nn 11from torch.testing import FileCheck 12 13 14# Make the helper files in test/ importable 15pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 16sys.path.append(pytorch_test_dir) 17from typing import Dict, Iterable, List, Optional, Tuple 18 19import torch.testing._internal.jit_utils 20from torch.testing._internal.common_utils import IS_SANDCASTLE, skipIfTorchDynamo 21from torch.testing._internal.jit_utils import JitTestCase, make_global 22 23 24if __name__ == "__main__": 25 raise RuntimeError( 26 "This test file is not meant to be run directly, use:\n\n" 27 "\tpython test/test_jit.py TESTNAME\n\n" 28 "instead." 29 ) 30 31 32class TestClassType(JitTestCase): 33 def test_reference_semantics(self): 34 """ 35 Test that modifications made to a class instance in TorchScript 36 are visible in eager. 37 """ 38 39 class Foo: 40 def __init__(self, a: int): 41 self.a = a 42 43 def set_a(self, value: int): 44 self.a = value 45 46 def get_a(self) -> int: 47 return self.a 48 49 @property 50 def attr(self): 51 return self.a 52 53 make_global(Foo) # see [local resolution in python] 54 55 def test_fn(obj: Foo): 56 obj.set_a(2) 57 58 scripted_fn = torch.jit.script(test_fn) 59 obj = torch.jit.script(Foo(1)) 60 self.assertEqual(obj.get_a(), 1) 61 self.assertEqual(obj.attr, 1) 62 63 scripted_fn(obj) 64 65 self.assertEqual(obj.get_a(), 2) 66 self.assertEqual(obj.attr, 2) 67 68 def test_get_with_method(self): 69 class FooTest: 70 def __init__(self, x): 71 self.foo = x 72 73 def getFooTest(self): 74 return self.foo 75 76 def fn(x): 77 foo = FooTest(x) 78 return foo.getFooTest() 79 80 input = torch.ones(2, 3) 81 self.assertEqual(fn(input), input) 82 83 def test_get_attr(self): 84 class FooTest: # noqa: B903 85 def __init__(self, x): 86 self.foo = x 87 88 @torch.jit.script 89 def fn(x): 90 foo = FooTest(x) 91 return 
foo.foo 92 93 input = torch.ones(2, 3) 94 self.assertEqual(fn(input), input) 95 96 def test_in(self): 97 class FooTest: # noqa: B903 98 def __init__(self) -> None: 99 pass 100 101 def __contains__(self, key: str) -> bool: 102 return key == "hi" 103 104 @torch.jit.script 105 def fn(): 106 foo = FooTest() 107 return "hi" in foo, "no" in foo 108 109 self.assertEqual(fn(), (True, False)) 110 111 def test_set_attr_in_method(self): 112 class FooTest: 113 def __init__(self, x: int) -> None: 114 self.foo = x 115 116 def incFooTest(self, y: int) -> None: 117 self.foo = self.foo + y 118 119 @torch.jit.script 120 def fn(x: int) -> int: 121 foo = FooTest(x) 122 foo.incFooTest(2) 123 return foo.foo 124 125 self.assertEqual(fn(1), 3) 126 127 def test_set_attr_type_mismatch(self): 128 with self.assertRaisesRegexWithHighlight( 129 RuntimeError, "Wrong type for attribute assignment", "self.foo = 10" 130 ): 131 132 @torch.jit.script 133 class FooTest: 134 def __init__(self, x): 135 self.foo = x 136 self.foo = 10 # should error since int != Tensor 137 138 def test_get_attr_not_initialized(self): 139 with self.assertRaisesRegexWithHighlight( 140 RuntimeError, "object has no attribute or method", "self.asdf" 141 ): 142 143 @torch.jit.script 144 class FooTest: 145 def __init__(self, x): 146 self.foo = x 147 148 def get_non_initialized(self): 149 return self.asdf # asdf isn't an attr 150 151 def test_set_attr_non_initialized(self): 152 with self.assertRaisesRegexWithHighlight( 153 RuntimeError, "Tried to set nonexistent attribute", "self.bar = y" 154 ): 155 156 @torch.jit.script 157 class FooTest: 158 def __init__(self, x): 159 self.foo = x 160 161 def set_non_initialized(self, y): 162 self.bar = y # can't assign to non-initialized attr 163 164 def test_schema_human_readable(self): 165 """ 166 Make sure that the schema is human readable, ie the mode parameter should read "nearest" instead of being displayed in octal 167 aten::__interpolate(Tensor input, int? size=None, float[]? 
scale_factor=None, 168 str mode='\156\145\141\162\145\163\164', bool? align_corners=None) -> (Tensor): 169 Expected a value of type 'Optional[int]' for argument 'size' but instead found type 'Tensor'. 170 """ 171 with self.assertRaisesRegexWithHighlight(RuntimeError, "nearest", ""): 172 173 @torch.jit.script 174 def FooTest(x): 175 return torch.nn.functional.interpolate(x, "bad") 176 177 def test_type_annotations(self): 178 with self.assertRaisesRegexWithHighlight( 179 RuntimeError, "Expected a value of type 'bool", "" 180 ): 181 182 @torch.jit.script # noqa: B903 183 class FooTest: # noqa: B903 184 def __init__(self, x: bool) -> None: 185 self.foo = x 186 187 @torch.jit.script 188 def fn(x): 189 FooTest(x) 190 191 fn(2) 192 193 def test_conditional_set_attr(self): 194 with self.assertRaisesRegexWithHighlight( 195 RuntimeError, "assignment cannot be in a control-flow block", "" 196 ): 197 198 @torch.jit.script 199 class FooTest: 200 def __init__(self, x): 201 if 1 == 1: 202 self.attr = x 203 204 def test_class_type_as_param(self): 205 class FooTest: # noqa: B903 206 def __init__(self, x): 207 self.attr = x 208 209 make_global(FooTest) # see [local resolution in python] 210 211 @torch.jit.script 212 def fn(foo: FooTest) -> torch.Tensor: 213 return foo.attr 214 215 @torch.jit.script 216 def fn2(x): 217 foo = FooTest(x) 218 return fn(foo) 219 220 input = torch.ones(1) 221 self.assertEqual(fn2(input), input) 222 223 def test_out_of_order_methods(self): 224 class FooTest: 225 def __init__(self, x): 226 self.x = x 227 self.x = self.get_stuff(x) 228 229 def get_stuff(self, y): 230 return self.x + y 231 232 @torch.jit.script 233 def fn(x): 234 f = FooTest(x) 235 return f.x 236 237 input = torch.ones(1) 238 self.assertEqual(fn(input), input + input) 239 240 def test_save_load_with_classes(self): 241 class FooTest: 242 def __init__(self, x): 243 self.x = x 244 245 def get_x(self): 246 return self.x 247 248 class MyMod(torch.jit.ScriptModule): 249 @torch.jit.script_method 250 
def forward(self, a): 251 foo = FooTest(a) 252 return foo.get_x() 253 254 m = MyMod() 255 256 buffer = io.BytesIO() 257 torch.jit.save(m, buffer) 258 259 # classes are globally registered for now, so we need to clear the JIT 260 # registry to simulate loading a new model 261 262 buffer.seek(0) 263 m_loaded = torch.jit.load(buffer) 264 265 input = torch.rand(2, 3) 266 output = m_loaded(input) 267 self.assertEqual(input, output) 268 269 def test_save_load_with_classes_returned(self): 270 class FooTest: 271 def __init__(self, x): 272 self.x = x 273 274 def clone(self): 275 clone = FooTest(self.x) 276 return clone 277 278 class MyMod(torch.jit.ScriptModule): 279 @torch.jit.script_method 280 def forward(self, a): 281 foo = FooTest(a) 282 foo_clone = foo.clone() 283 return foo_clone.x 284 285 m = MyMod() 286 287 buffer = io.BytesIO() 288 torch.jit.save(m, buffer) 289 290 # classes are globally registered for now, so we need to clear the JIT 291 # registry to simulate loading a new model 292 torch.testing._internal.jit_utils.clear_class_registry() 293 294 buffer.seek(0) 295 m_loaded = torch.jit.load(buffer) 296 297 input = torch.rand(2, 3) 298 output = m_loaded(input) 299 self.assertEqual(input, output) 300 301 def test_save_load_with_classes_nested(self): 302 class FooNestedTest: # noqa: B903 303 def __init__(self, y): 304 self.y = y 305 306 class FooNestedTest2: 307 def __init__(self, y): 308 self.y = y 309 self.nested = FooNestedTest(y) 310 311 class FooTest: 312 def __init__(self, x): 313 self.class_attr = FooNestedTest(x) 314 self.class_attr2 = FooNestedTest2(x) 315 self.x = self.class_attr.y + self.class_attr2.y 316 317 class MyMod(torch.jit.ScriptModule): 318 @torch.jit.script_method 319 def forward(self, a): 320 foo = FooTest(a) 321 return foo.x 322 323 m = MyMod() 324 325 buffer = io.BytesIO() 326 torch.jit.save(m, buffer) 327 328 # classes are globally registered for now, so we need to clear the JIT 329 # registry to simulate loading a new model 330 
torch.testing._internal.jit_utils.clear_class_registry() 331 332 buffer.seek(0) 333 m_loaded = torch.jit.load(buffer) 334 335 input = torch.rand(2, 3) 336 output = m_loaded(input) 337 self.assertEqual(2 * input, output) 338 339 def test_python_interop(self): 340 class Foo: # noqa: B903 341 def __init__(self, x, y): 342 self.x = x 343 self.y = y 344 345 make_global(Foo) # see [local resolution in python] 346 347 @torch.jit.script 348 def use_foo(foo: Foo) -> Foo: 349 return foo 350 351 # create from python 352 x = torch.ones(2, 3) 353 y = torch.zeros(2, 3) 354 f = Foo(x, y) 355 356 self.assertEqual(x, f.x) 357 self.assertEqual(y, f.y) 358 359 # pass in and out of script 360 f2 = use_foo(f) 361 362 self.assertEqual(x, f2.x) 363 self.assertEqual(y, f2.y) 364 365 def test_class_specialization(self): 366 class Foo: # noqa: B903 367 def __init__(self, x, y): 368 self.x = x 369 self.y = y 370 371 make_global(Foo) # see [local resolution in python] 372 373 def use_foo(foo: Foo, foo2: Foo, tup: Tuple[Foo, Foo]) -> torch.Tensor: 374 a, b = tup 375 return foo.x + foo2.y + a.x + b.y 376 377 # create from python 378 x = torch.ones(2, 3) 379 y = torch.zeros(2, 3) 380 f = Foo(x, y) 381 f2 = Foo(x * 2, y * 3) 382 f3 = Foo(x * 4, y * 4) 383 384 input = (f, f2, (f, f3)) 385 sfoo = self.checkScript(use_foo, input) 386 graphstr = str(sfoo.graph_for(*input)) 387 FileCheck().check_count("prim::GetAttr", 4).run(graphstr) 388 389 def test_class_sorting(self): 390 class Foo: # noqa: B903 391 def __init__(self, x: int) -> None: 392 self.x = x 393 394 def __lt__(self, other) -> bool: 395 # type: (Foo) -> bool 396 return self.x < other.x 397 398 def getVal(self): 399 return self.x 400 401 make_global(Foo) # see [local resolution in python] 402 403 def test(li: List[Foo], reverse: bool = False) -> Tuple[List[int], List[int]]: 404 li_sorted = sorted(li) 405 ret_sorted = torch.jit.annotate(List[int], []) 406 for foo in li_sorted: 407 ret_sorted.append(foo.getVal()) 408 409 
li.sort(reverse=reverse) 410 ret_sort = torch.jit.annotate(List[int], []) 411 for foo in li: 412 ret_sort.append(foo.getVal()) 413 return ret_sorted, ret_sort 414 415 self.checkScript(test, ([Foo(2), Foo(1), Foo(3)],)) 416 self.checkScript(test, ([Foo(2), Foo(1), Foo(3)], True)) 417 self.checkScript(test, ([Foo(2)],)) 418 self.checkScript(test, ([],)) 419 420 @torch.jit.script 421 def test_list_no_reverse(): 422 li = [Foo(3), Foo(1)] 423 li.sort() 424 return li[0].getVal() 425 426 self.assertEqual(test_list_no_reverse(), 1) 427 428 @torch.jit.script 429 def test_sorted_copies(): 430 li = [Foo(3), Foo(1)] 431 li_sorted = sorted(li) 432 return li[0].getVal(), li_sorted[0].getVal() 433 434 self.assertEqual(test_sorted_copies(), (3, 1)) 435 436 @torch.jit.script 437 def test_nested_inside_tuple(): 438 li = [(1, Foo(12)), (1, Foo(11))] 439 li.sort() 440 return [(li[0][0], li[0][1].getVal()), (li[1][0], li[1][1].getVal())] 441 442 self.assertEqual(test_nested_inside_tuple(), [(1, 11), (1, 12)]) 443 444 with self.assertRaisesRegexWithHighlight( 445 RuntimeError, "bool' for argument 'reverse", "" 446 ): 447 448 @torch.jit.script 449 def test(): 450 li = [Foo(1)] 451 li.sort(li) 452 return li 453 454 test() 455 456 with self.assertRaisesRegexWithHighlight( 457 RuntimeError, "must define a __lt__", "" 458 ): 459 460 @torch.jit.script 461 class NoMethod: 462 def __init__(self) -> None: 463 pass 464 465 @torch.jit.script 466 def test(): 467 li = [NoMethod(), NoMethod()] 468 li.sort() 469 return li 470 471 test() 472 473 @torch.jit.script 474 class WrongLt: 475 def __init__(self) -> None: 476 pass 477 478 # lt method defined with the wrong signature 479 def __lt__(self, other): 480 pass 481 482 with self.assertRaisesRegexWithHighlight( 483 RuntimeError, "must define a __lt__", "" 484 ): 485 486 @torch.jit.script 487 def test(): 488 li = [WrongLt(), WrongLt()] 489 li.sort() 490 return li 491 492 test() 493 494 def test_class_inheritance(self): 495 @torch.jit.script 496 class 
Base: 497 def __init__(self) -> None: 498 self.b = 2 499 500 def two(self, x): 501 return x + self.b 502 503 with self.assertRaisesRegexWithHighlight( 504 RuntimeError, "does not support inheritance", "" 505 ): 506 507 @torch.jit.script 508 class Derived(Base): 509 def two(self, x): 510 return x + self.b + 2 511 512 def test_class_inheritance_implicit(self): 513 """ 514 Test that inheritance is detected in 515 implicit scripting codepaths (e.g. try_ann_to_type). 516 """ 517 518 class A: 519 def __init__(self, t): 520 self.t = t 521 522 @staticmethod 523 def f(a: torch.Tensor): 524 return A(a + 1) 525 526 class B(A): 527 def __init__(self, t): 528 self.t = t + 10 529 530 @staticmethod 531 def f(a: torch.Tensor): 532 return A(a + 1) 533 534 x = A(torch.tensor([3])) 535 536 def fun(x: Any): 537 if isinstance(x, A): 538 return A.f(x.t) 539 else: 540 return B.f(x.t) 541 542 with self.assertRaisesRegexWithHighlight( 543 RuntimeError, "object has no attribute or method", "" 544 ): 545 sc = torch.jit.script(fun) 546 547 @skipIfTorchDynamo("Test does not work with TorchDynamo") 548 @unittest.skipIf(IS_SANDCASTLE, "Importing like this doesn't work in fbcode") 549 def test_imported_classes(self): 550 import jit._imported_class_test.bar 551 import jit._imported_class_test.foo 552 import jit._imported_class_test.very.very.nested 553 554 class MyMod(torch.jit.ScriptModule): 555 @torch.jit.script_method 556 def forward(self, a): 557 foo = jit._imported_class_test.foo.FooSameName(a) 558 bar = jit._imported_class_test.bar.FooSameName(a) 559 three = jit._imported_class_test.very.very.nested.FooUniqueName(a) 560 return foo.x + bar.y + three.y 561 562 m = MyMod() 563 564 buffer = io.BytesIO() 565 torch.jit.save(m, buffer) 566 567 # classes are globally registered for now, so we need to clear the JIT 568 # registry to simulate loading a new model 569 torch.testing._internal.jit_utils.clear_class_registry() 570 571 buffer.seek(0) 572 m_loaded = torch.jit.load(buffer) 573 574 input = 
torch.rand(2, 3) 575 output = m_loaded(input) 576 self.assertEqual(3 * input, output) 577 578 def test_interface(self): 579 @torch.jit.script 580 class Foo: 581 def __init__(self) -> None: 582 pass 583 584 def one(self, x, y): 585 return x + y 586 587 def two(self, x): 588 return 2 * x 589 590 @torch.jit.script 591 class Bar: 592 def __init__(self) -> None: 593 pass 594 595 def one(self, x, y): 596 return x * y 597 598 def two(self, x): 599 return 2 / x 600 601 @torch.jit.interface 602 class OneTwo: 603 def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 604 pass 605 606 def two(self, x: torch.Tensor) -> torch.Tensor: 607 pass 608 609 @torch.jit.interface 610 class OneTwoThree: 611 def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 612 pass 613 614 def two(self, x: torch.Tensor) -> torch.Tensor: 615 pass 616 617 def three(self, x: torch.Tensor) -> torch.Tensor: 618 pass 619 620 @torch.jit.interface 621 class OneTwoWrong: 622 def one(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 623 pass 624 625 def two(self, x: int) -> int: 626 pass 627 628 @torch.jit.script 629 class NotMember: 630 def __init__(self) -> None: 631 pass 632 633 def one(self, x, y): 634 return x + y 635 636 # missing two 637 638 @torch.jit.script 639 class NotMember2: 640 def __init__(self) -> None: 641 pass 642 643 def one(self, x, y): 644 return x + y 645 646 def two(self, x: int) -> int: 647 return 3 648 649 make_global(Foo, Bar, OneTwo, OneTwoThree, OneTwoWrong, NotMember, NotMember2) 650 651 def use_them(x): 652 a = Foo() 653 b = Bar() 654 c = torch.jit.annotate(List[OneTwo], [a, b]) 655 for i in range(len(c)): 656 x = c[i].one(x, x) 657 x = c[i].two(x) 658 return x 659 660 self.checkScript(use_them, (torch.rand(3, 4),)) 661 662 @torch.jit.script 663 def as_interface(x: OneTwo) -> OneTwo: 664 return x 665 666 @torch.jit.script 667 def inherit(x: OneTwoThree) -> OneTwo: 668 return as_interface(x) 669 670 with self.assertRaisesRegexWithHighlight( 671 
RuntimeError, "does not have method", "" 672 ): 673 674 @torch.jit.script 675 def wrong1(): 676 return as_interface(NotMember()) 677 678 with self.assertRaisesRegexWithHighlight( 679 RuntimeError, "is not compatible with interface", "" 680 ): 681 682 @torch.jit.script 683 def wrong2(): 684 return as_interface(NotMember2()) 685 686 with self.assertRaisesRegexWithHighlight( 687 RuntimeError, "does not have method", "" 688 ): 689 690 @torch.jit.script 691 def wrong3(): 692 return inherit(as_interface(Foo())) 693 694 with self.assertRaisesRegexWithHighlight( 695 RuntimeError, "is not compatible with interface", "" 696 ): 697 698 @torch.jit.script 699 def wrong4(x: OneTwoWrong) -> int: 700 return as_interface(x) 701 702 # Test interface/class python assignment 703 class TestPyAssign(nn.Module): 704 def __init__(self) -> None: 705 super().__init__() 706 self.proxy_mod = Foo() 707 708 def forward(self, x): 709 return self.proxy_mod.two(x) 710 711 TestPyAssign.__annotations__ = {"proxy_mod": OneTwo} 712 713 input = torch.rand(3, 4) 714 scripted_pyassign_mod = torch.jit.script(TestPyAssign()) 715 imported_mod = self.getExportImportCopy(scripted_pyassign_mod) 716 self.assertEqual(scripted_pyassign_mod(input), imported_mod(input)) 717 718 class TestPyAssignError(nn.Module): 719 def __init__(self, obj): 720 super().__init__() 721 self.proxy_mod = obj 722 723 def forward(self, x): 724 return self.proxy_mod.two(x) 725 726 TestPyAssignError.__annotations__ = {"proxy_mod": OneTwoThree} 727 728 with self.assertRaisesRegexWithHighlight( 729 RuntimeError, "is not compatible with interface __torch__", "" 730 ): 731 torch.jit.script(TestPyAssignError(Foo())) 732 733 # test pure python object assignment to interface fails 734 class PyClass: 735 def __init__(self) -> None: 736 pass 737 738 with self.assertRaisesRegexWithHighlight( 739 RuntimeError, "the value is not a TorchScript compatible type", "" 740 ): 741 torch.jit.script(TestPyAssignError(PyClass())) 742 # TODO test: 
interface-interface class-interface inheritance errors, 743 # NamedTuple inheritance errors 744 745 def test_overloaded_fn(self): 746 @torch.jit.script 747 class Foo: 748 def __init__(self, x): 749 self.x = x 750 751 def __len__(self) -> int: 752 return len(self.x) 753 754 def __neg__(self): 755 self.x = -self.x 756 return self 757 758 def __mul__(self, other: torch.Tensor) -> torch.Tensor: 759 return self.x * other 760 761 def test_overload(): 762 a = Foo(torch.ones([3, 3])) 763 return len(a), -a * torch.zeros([3, 3]) 764 765 make_global(Foo) # see [local resolution in python] 766 767 self.checkScript(test_overload, ()) 768 # unary ops tested above 769 770 # TODO - support compiling classes from strings in jit.CompilationUnit 771 @torch.jit.script 772 class MyClass: 773 def __init__(self, x: int) -> None: 774 self.x = x 775 776 def __add__(self, other: int) -> int: 777 return self.x + other 778 779 def __sub__(self, other: int) -> int: 780 return self.x - other 781 782 def __mul__(self, other: int) -> int: 783 return self.x * other 784 785 def __pow__(self, other: int) -> int: 786 return int(self.x**other) 787 788 def __truediv__(self, other: int) -> float: 789 return self.x / other 790 791 def __mod__(self, other: int) -> int: 792 return self.x % other 793 794 def __ne__(self, other: int) -> bool: 795 return self.x != other 796 797 def __eq__(self, other: int) -> bool: 798 return self.x == other 799 800 def __lt__(self, other: int) -> bool: 801 return self.x < other 802 803 def __gt__(self, other: int) -> bool: 804 return self.x > other 805 806 def __le__(self, other: int) -> bool: 807 return self.x <= other 808 809 def __ge__(self, other: int) -> bool: 810 return self.x >= other 811 812 def __and__(self, other: int) -> int: 813 return self.x & other 814 815 def __or__(self, other: int) -> int: 816 return self.x | other 817 818 def __xor__(self, other: int) -> int: 819 return self.x ^ other 820 821 def __getitem__(self, other: int) -> int: 822 return other + 1 
823 824 def __setitem__(self, idx: int, val: int) -> None: 825 self.x = val * idx 826 827 def __call__(self, val: int) -> int: 828 return self.x * val * 3 829 830 make_global(Foo) # see [local resolution in python] 831 832 def add(): 833 return MyClass(4) + 3 834 835 def sub(): # noqa: E306 836 return MyClass(4) - 3 837 838 def mul(): # noqa: E306 839 return MyClass(4) * 3 840 841 def pow(): # noqa: E306 842 return MyClass(4) ** 3 843 844 def truediv(): # noqa: E306 845 return MyClass(4) / 3 846 847 def ne(): # noqa: E306 848 return MyClass(4) != 3 849 850 def eq(): # noqa: E306 851 return MyClass(4) == 3 852 853 def lt(): # noqa: E306 854 return MyClass(4) < 3 855 856 def gt(): # noqa: E306 857 return MyClass(4) > 3 858 859 def le(): # noqa: E306 860 return MyClass(4) <= 3 861 862 def ge(): # noqa: E306 863 return MyClass(4) >= 3 864 865 def _and(): # noqa: E306 866 return MyClass(4) & 3 867 868 def _or(): # noqa: E306 869 return MyClass(4) | 3 870 871 def _xor(): # noqa: E306 872 return MyClass(4) ^ 3 873 874 def getitem(): # noqa: E306 875 return MyClass(4)[1] 876 877 def setitem(): # noqa: E306 878 a = MyClass(4) 879 a[1] = 5 880 return a.x 881 882 def call(): # noqa: E306 883 a = MyClass(5) 884 return a(2) 885 886 ops = [ 887 add, 888 sub, 889 mul, 890 pow, 891 ne, 892 eq, 893 lt, 894 gt, 895 le, 896 ge, 897 _and, 898 _or, 899 _xor, 900 getitem, 901 setitem, 902 call, 903 ] 904 905 ops.append(truediv) 906 for func in ops: 907 self.checkScript(func, ()) 908 909 with self.assertRaisesRegexWithHighlight( 910 RuntimeError, "object has no attribute or method", "" 911 ): 912 913 @torch.jit.script 914 def test(): 915 return Foo(torch.tensor(1)) + Foo(torch.tensor(1)) 916 917 def test_cast_overloads(self): 918 @torch.jit.script 919 class Foo: 920 def __init__(self, val: float) -> None: 921 self.val = val 922 923 def __int__(self): 924 return int(self.val) 925 926 def __float__(self): 927 return self.val 928 929 def __bool__(self): 930 return bool(self.val) 931 932 def 
__str__(self): 933 return str(self.val) 934 935 make_global(Foo) # see [local resolution in python] 936 937 def test(foo: Foo) -> Tuple[int, float, bool]: 938 if foo: 939 pass 940 return int(foo), float(foo), bool(foo) 941 942 fn = torch.jit.script(test) 943 self.assertEqual(fn(Foo(0.5)), test(0.5)) 944 self.assertEqual(fn(Foo(0.0)), test(0.0)) 945 # str has slightly different formatting 946 self.assertTrue("0.5" in (str(Foo(0.5)))) 947 self.assertTrue("0." in (str(Foo(0.0)))) 948 949 @torch.jit.script 950 class BadBool: 951 def __init__(self) -> None: 952 pass 953 954 def __bool__(self): 955 return (1, 2) 956 957 with self.assertRaisesRegexWithHighlight( 958 RuntimeError, "expected a bool expression for condition", "" 959 ): 960 961 @torch.jit.script 962 def test(): 963 if BadBool(): 964 print(1) 965 966 def test_init_compiled_first(self): 967 @torch.jit.script # noqa: B903 968 class Foo: # noqa: B903 969 def __before_init__(self): 970 # accessing this field should not throw, since __init__ should be compiled 971 return self.x 972 973 def __init__(self, x, y): 974 self.x = x 975 self.y = y 976 977 def test_class_constructs_itself(self): 978 @torch.jit.script # noqa: B903 979 class LSTMStateStack: # noqa: B903 980 def __init__(self, num_layers: int, hidden_size: int) -> None: 981 self.num_layers = num_layers 982 self.hidden_size = hidden_size 983 self.last_state = ( 984 torch.zeros(num_layers, 1, hidden_size), 985 torch.zeros(num_layers, 1, hidden_size), 986 ) 987 self.stack = [(self.last_state[0][-1], self.last_state[0][-1])] 988 989 def copy(self): 990 # should be able to construct a class inside its own methods 991 other = LSTMStateStack(self.num_layers, self.hidden_size) 992 other.stack = list(self.stack) 993 return other 994 995 def test_optional_type_promotion(self): 996 @torch.jit.script 997 class Leaf: 998 def __init__(self) -> None: 999 self.x = 1 1000 1001 # should not throw 1002 @torch.jit.script # noqa: B903 1003 class Tree: # noqa: B903 1004 def 
__init__(self) -> None: 1005 self.child = torch.jit.annotate(Optional[Leaf], None) 1006 1007 def add_child(self, child: Leaf) -> None: 1008 self.child = child 1009 1010 def test_recursive_class(self): 1011 """ 1012 Recursive class types not yet supported. We should give a good error message. 1013 """ 1014 with self.assertRaises(RuntimeError): 1015 1016 @torch.jit.script # noqa: B903 1017 class Tree: # noqa: B903 1018 def __init__(self) -> None: 1019 self.parent = torch.jit.annotate(Optional[Tree], None) 1020 1021 def test_class_constant(self): 1022 class M(torch.nn.Module): 1023 __constants__ = ["w"] 1024 1025 def __init__(self, w): 1026 super().__init__() 1027 self.w = w 1028 1029 def forward(self, x): 1030 # Make sure class constant is accessible in method 1031 y = self.w 1032 return x, y 1033 1034 # Test serialization/deserialization of class constant 1035 for c in (2, 1.0, None, True, "str", (2, 3), [5.9, 7.3]): 1036 m = torch.jit.script(M(c)) 1037 buffer = io.BytesIO() 1038 torch.jit.save(m, buffer) 1039 1040 buffer.seek(0) 1041 m_loaded = torch.jit.load(buffer) 1042 input = torch.rand(2, 3) 1043 self.assertEqual(m(input), m_loaded(input)) 1044 # Make sure class constant is accessible from module 1045 self.assertEqual(m.w, m_loaded.w) 1046 1047 def test_py_class_to_ivalue_missing_attribute(self): 1048 class Foo: 1049 i: int 1050 f: float 1051 1052 def __init__(self, i: int, f: float): 1053 self.i = i 1054 self.f = f 1055 1056 make_global(Foo) # see [local resolution in python] 1057 1058 @torch.jit.script 1059 def test_fn(x: Foo) -> float: 1060 return x.i + x.f 1061 1062 test_fn(Foo(3, 4.0)) 1063 1064 with self.assertRaisesRegexWithHighlight( 1065 RuntimeError, "missing attribute i", "" 1066 ): 1067 test_fn(torch.rand(3, 4)) 1068 1069 def test_unused_method(self): 1070 """ 1071 Test unused methods on scripted classes. 
1072 """ 1073 1074 @torch.jit.script 1075 class Unused: 1076 def __init__(self) -> None: 1077 self.count: int = 0 1078 self.items: List[int] = [] 1079 1080 def used(self): 1081 self.count += 1 1082 return self.count 1083 1084 @torch.jit.unused 1085 def unused(self, x: int, y: Iterable[int], **kwargs) -> int: 1086 a = next(self.items) 1087 return a 1088 1089 def uses_unused(self) -> int: 1090 return self.unused(y="hi", x=3) 1091 1092 class ModuleWithUnused(nn.Module): 1093 def __init__(self) -> None: 1094 super().__init__() 1095 self.obj = Unused() 1096 1097 def forward(self): 1098 return self.obj.used() 1099 1100 @torch.jit.export 1101 def calls_unused(self): 1102 return self.obj.unused(3, "hi") 1103 1104 @torch.jit.export 1105 def calls_unused_indirectly(self): 1106 return self.obj.uses_unused() 1107 1108 python_module = ModuleWithUnused() 1109 script_module = torch.jit.script(ModuleWithUnused()) 1110 1111 # Forward should work because it does not used any methods marked unused. 1112 self.assertEqual(python_module.forward(), script_module.forward()) 1113 1114 # Calling a method marked unused should throw. 1115 with self.assertRaises(torch.jit.Error): 1116 script_module.calls_unused() 1117 1118 with self.assertRaises(torch.jit.Error): 1119 script_module.calls_unused_indirectly() 1120 1121 def test_self_referential_method(self): 1122 """ 1123 Test that a scripted class can have a method that refers to the class itself 1124 in its type annotations. 
1125 """ 1126 1127 @torch.jit.script 1128 class Meta: 1129 def __init__(self, a: int): 1130 self.a = a 1131 1132 def method(self, other: List["Meta"]) -> "Meta": 1133 return Meta(len(other)) 1134 1135 class ModuleWithMeta(torch.nn.Module): 1136 def __init__(self, a: int): 1137 super().__init__() 1138 self.meta = Meta(a) 1139 1140 def forward(self): 1141 new_obj = self.meta.method([self.meta]) 1142 return new_obj.a 1143 1144 self.checkModule(ModuleWithMeta(5), ()) 1145 1146 def test_type_annotation(self): 1147 """ 1148 Test that annotating container attributes with types works correctly 1149 """ 1150 1151 @torch.jit.script 1152 class CompetitiveLinkingTokenReplacementUtils: 1153 def __init__(self) -> None: 1154 self.my_list: List[Tuple[float, int, int]] = [] 1155 self.my_dict: Dict[int, int] = {} 1156 1157 @torch.jit.script 1158 def foo(): 1159 y = CompetitiveLinkingTokenReplacementUtils() 1160 new_dict: Dict[int, int] = {1: 1, 2: 2} 1161 y.my_dict = new_dict 1162 1163 new_list: List[Tuple[float, int, int]] = [(1.0, 1, 1)] 1164 y.my_list = new_list 1165 return y 1166 1167 def test_default_args(self): 1168 """ 1169 Test that methods on class types can have default arguments. 
1170 """ 1171 1172 @torch.jit.script 1173 class ClassWithDefaultArgs: 1174 def __init__( 1175 self, 1176 a: int = 1, 1177 b: Optional[List[int]] = None, 1178 c: Tuple[int, int, int] = (1, 2, 3), 1179 d: Optional[Dict[int, int]] = None, 1180 e: Optional[str] = None, 1181 ): 1182 self.int = a 1183 self.tup = c 1184 self.str = e 1185 1186 self.list = [1, 2, 3] 1187 if b is not None: 1188 self.list = b 1189 1190 self.dict = {1: 2, 3: 4} 1191 if d is not None: 1192 self.dict = d 1193 1194 def add(self, b: int, scale: float = 1.0) -> float: 1195 return self.int * scale + b 1196 1197 def all_defaults() -> int: 1198 obj: ClassWithDefaultArgs = ClassWithDefaultArgs() 1199 return obj.int + obj.list[2] + obj.tup[1] 1200 1201 def some_defaults() -> int: 1202 obj: ClassWithDefaultArgs = ClassWithDefaultArgs(b=[5, 6, 7]) 1203 return obj.int + obj.list[2] + obj.dict[1] 1204 1205 def override_defaults() -> int: 1206 obj: ClassWithDefaultArgs = ClassWithDefaultArgs( 1207 3, [9, 10, 11], (12, 13, 14), {3: 4}, "str" 1208 ) 1209 s: int = obj.int 1210 1211 for x in obj.list: 1212 s += x 1213 1214 for y in obj.tup: 1215 s += y 1216 1217 s += obj.dict[3] 1218 1219 st = obj.str 1220 if st is not None: 1221 s += len(st) 1222 1223 return s 1224 1225 def method_defaults() -> float: 1226 obj: ClassWithDefaultArgs = ClassWithDefaultArgs() 1227 return obj.add(3) + obj.add(3, 0.25) 1228 1229 self.checkScript(all_defaults, ()) 1230 self.checkScript(some_defaults, ()) 1231 self.checkScript(override_defaults, ()) 1232 self.checkScript(method_defaults, ()) 1233 1234 # The constructor of this class below has some arguments without default values. 
        class ClassWithSomeDefaultArgs:  # noqa: B903
            def __init__(
                self,
                a: int,
                b: int = 1,
            ):
                self.a = a
                self.b = b

        def default_b() -> int:
            obj: ClassWithSomeDefaultArgs = ClassWithSomeDefaultArgs(1)
            return obj.a + obj.b

        def set_b() -> int:
            obj: ClassWithSomeDefaultArgs = ClassWithSomeDefaultArgs(1, 4)
            return obj.a + obj.b

        self.checkScript(default_b, ())
        self.checkScript(set_b, ())

        # The constructor of the class below has a mutable default argument
        # (a list literal). Scripting a function that constructs it is expected
        # to raise "Mutable default parameters are not supported".
        class ClassWithMutableArgs:  # noqa: B903
            def __init__(
                self,
                a: List[int] = [1, 2, 3],  # noqa: B006
            ):
                self.a = a

        def should_fail():
            obj: ClassWithMutableArgs = ClassWithMutableArgs()

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Mutable default parameters are not supported", ""
        ):
            torch.jit.script(should_fail)

    def test_staticmethod(self):
        """
        Test static methods on class types: a staticmethod that calls the
        class constructor, and a staticmethod that calls another staticmethod,
        both invoked from a free function compiled with checkScript.
        """

        @torch.jit.script
        class ClassWithStaticMethod:
            def __init__(self, a: int, b: int):
                self.a: int = a
                self.b: int = b

            def get_a(self):
                return self.a

            def get_b(self):
                return self.b

            def __eq__(self, other: "ClassWithStaticMethod"):
                return self.a == other.a and self.b == other.b

            # staticmethod that calls constructor.
            @staticmethod
            def create(args: List["ClassWithStaticMethod"]) -> "ClassWithStaticMethod":
                return ClassWithStaticMethod(args[0].a, args[0].b)

            # staticmethod that calls another staticmethod.
            @staticmethod
            def create_from(a: int, b: int) -> "ClassWithStaticMethod":
                a = ClassWithStaticMethod(a, b)
                return ClassWithStaticMethod.create([a])

        # Script function that calls staticmethod.
        def test_function(a: int, b: int) -> "ClassWithStaticMethod":
            return ClassWithStaticMethod.create_from(a, b)

        # Register the locally defined class so the string annotations in
        # test_function can be resolved when it is scripted.
        make_global(ClassWithStaticMethod)

        self.checkScript(test_function, (1, 2))

    def test_classmethod(self):
        """
        Test classmethods on class types. Only calling a classmethod through
        an instance is exercised here; calling it through the class itself is
        not supported.
        """

        @torch.jit.script
        class ClassWithClassMethod:
            def __init__(self, a: int):
                self.a: int = a

            def __eq__(self, other: "ClassWithClassMethod"):
                return self.a == other.a

            @classmethod
            def create(cls, a: int) -> "ClassWithClassMethod":
                return cls(a)

        make_global(ClassWithClassMethod)

        def test_function(a: int) -> "ClassWithClassMethod":
            x = ClassWithClassMethod(a)
            # Support calling classmethod with an instance
            # Calling with the class is not supported.
            return x.create(a)

        self.checkScript(test_function, (1,))

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_properties(self):
        """
        Test that a scripted class can make use of the @property decorator:
        getters and setters, properties excluded from compilation (via
        ``__jit_unused_properties__`` or ``@torch.jit.unused``), a property that
        calls a free function, and a property read from another method, all
        exercised from a module's forward().
        """

        def free_function(x: int) -> int:
            return x + 1

        @torch.jit.script
        class Properties:
            # Properties named here are skipped by the TorchScript compiler.
            __jit_unused_properties__ = ["unsupported"]

            def __init__(self, a: int):
                self.a = a

            @property
            def attr(self) -> int:
                return self.a - 1

            @property
            def unsupported(self) -> int:
                return sum([self.a])

            # Alternative to __jit_unused_properties__: mark the property itself unused.
            @torch.jit.unused
            @property
            def unsupported_2(self) -> int:
                return sum([self.a])

            @unsupported_2.setter
            def unsupported_2(self, value):
                self.a = sum([self.a])

            @attr.setter
            def attr(self, value: int):
                self.a = value + 3

        @torch.jit.script
        class NoSetter:
            def __init__(self, a: int):
                self.a = a

            @property
            def attr(self) -> int:
                return free_function(self.a)

        @torch.jit.script
        class MethodThatUsesProperty:
            def __init__(self, a: int):
                self.a = a

            @property
            def attr(self) -> int:
                return self.a - 2

            @attr.setter
            def attr(self, value: int):
                self.a = value + 4

            def forward(self):
                return self.attr

        class ModuleWithProperties(torch.nn.Module):
            def __init__(self, a: int):
                super().__init__()
                self.props = Properties(a)

            def forward(self, a: int, b: int, c: int, d: int):
                self.props.attr = a
                props = Properties(b)
                no_setter = NoSetter(c)
                method_uses_property = MethodThatUsesProperty(a + b)

                props.attr = c
                method_uses_property.attr = d

                return self.props.attr + no_setter.attr + method_uses_property.forward()

        self.checkModule(
            ModuleWithProperties(5),
            (
                5,
                6,
                7,
                8,
            ),
        )

    def test_custom_delete(self):
        """
        Test that del can be called on an instance of a class that
        overrides __delitem__, and that scripting fails with
        "Class does not define __delitem__" when the class does not.
        """

        class Example:
            def __init__(self) -> None:
                self._data: Dict[str, torch.Tensor] = {"1": torch.tensor(1.0)}

            def check(self, key: str) -> bool:
                return key in self._data

            def __delitem__(self, key: str):
                del self._data[key]

        def fn() -> bool:
            example = Example()
            del example["1"]
            return example.check("1")

        self.checkScript(fn, ())

        # Test the case in which the class does not have __delitem__ defined.
        class NoDelItem:
            def __init__(self) -> None:
                self._data: Dict[str, torch.Tensor] = {"1": torch.tensor(1.0)}

            def check(self, key: str) -> bool:
                return key in self._data

        def fn() -> bool:
            example = NoDelItem()
            key = "1"
            del example[key]
            return example.check(key)

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Class does not define __delitem__", "example[key]"
        ):
            self.checkScript(fn, ())

    def test_recursive_script_builtin_type_resolution(self):
        """
        Test resolution of built-in torch types (e.g. torch.Tensor, torch.device),
        including local aliases of them, when a class is recursively compiled.
        Each call_* function is checked against eager execution and round-tripped
        through export/import.
        """
        # A is not decorated with @torch.jit.script; it is compiled implicitly
        # because the call_* functions below use it.
        # Local aliases of builtin types, referenced in A's annotations.
        tensor_t = torch.Tensor
        device_t = torch.device
        device_ty = torch.device

        class A:
            def __init__(self) -> None:
                pass

            def f(self, x: tensor_t, y: torch.device) -> tensor_t:
                return x.to(device=y)

            def g(self, x: device_t) -> device_ty:
                return x

            # Not called below; presumably present so the self-referencing
            # "A" annotation must be resolved when A is compiled.
            def h(self, a: "A") -> "A":
                return A()

            def i(self, a: List[int]) -> int:
                return a[0]

            def j(self, l: List[device_t]) -> device_ty:
                return l[0]

        def call_f():
            a = A()
            return a.f(torch.tensor([1]), torch.device("cpu"))

        def call_g():
            a = A()
            return a.g(torch.device("cpu"))

        def call_i():
            a = A()
            return a.i([3])

        def call_j():
            a = A()
            return a.j([torch.device("cpu"), torch.device("cpu")])

        for fn in [call_f, call_g, call_i, call_j]:
            self.checkScript(fn, ())
            s = self.getExportImportCopy(torch.jit.script(fn))
            self.assertEqual(s(), fn())

    def test_recursive_script_module_builtin_type_resolution(self):
        """
        Test resolution of built-in torch types (e.g. torch.Tensor, torch.device)
        when a class is recursively compiled as part of compiling a module.
        """

        # Not decorated: compiled implicitly because A.forward constructs it.
        class Wrapper:
            def __init__(self, t):
                self.t = t

            # `l` is unused in the body; its List[torch.device] annotation is
            # what exercises builtin type resolution.
            def to(self, l: List[torch.device], device: Optional[torch.device] = None):
                return self.t.to(device=device)

        class A(nn.Module):
            def forward(self):
                return Wrapper(torch.rand(4, 4))

        scripted = torch.jit.script(A())
        self.getExportImportCopy(scripted)

    def test_class_attribute_wrong_type(self):
        """
        Test the error message raised when converting an instance of a class
        type to an IValue whose attribute has the wrong type.
        """

        # `val` is unannotated, so TorchScript types it as Tensor; the str
        # values assigned in Mod.__init__ below cannot be cast to that type.
        @torch.jit.script  # noqa: B903
        class ValHolder:  # noqa: B903
            def __init__(self, val):
                self.val = val

        class Mod(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod1 = ValHolder("1")
                self.mod2 = ValHolder("2")

            def forward(self, cond: bool):
                if cond:
                    mod = self.mod1
                else:
                    mod = self.mod2
                return mod.val

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Could not cast attribute 'val' to type Tensor", ""
        ):
            torch.jit.script(Mod())

    def test_recursive_scripting(self):
        """
        Test that class types are recursively scripted when a Python instance of one
        is encountered as a module attribute.
        """

        class Class:
            def __init__(self, a: int):
                self.a = a

            def get_a(self) -> int:
                return self.a

        class M(torch.nn.Module):
            def __init__(self, obj):
                super().__init__()
                self.obj = obj

            def forward(self) -> int:
                return self.obj.get_a()

        self.checkModule(M(Class(4)), ())

    def test_recursive_scripting_failed(self):
        """
        Test that module attributes of class type that fail to script
        are added as failed attributes and do not cause compilation itself
        to fail unless they are used in scripted code.
        """

        class UnscriptableClass:
            def __init__(self, a: int):
                self.a = a

            # Presumably this issubclass call is what keeps the class from
            # compiling under TorchScript.
            def get_a(self) -> bool:
                return issubclass(self.a, int)

        # This Module has an attribute of type UnscriptableClass
        # and tries to use it in scripted code. This should fail.
        class ShouldNotCompile(torch.nn.Module):
            def __init__(self, obj):
                super().__init__()
                self.obj = obj

            def forward(self) -> bool:
                return self.obj.get_a()

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "failed to convert Python type", ""
        ):
            torch.jit.script(ShouldNotCompile(UnscriptableClass(4)))

        # This Module has an attribute of type UnscriptableClass
        # and does not try to use it in scripted code. This should not fail.
        class ShouldCompile(torch.nn.Module):
            def __init__(self, obj):
                super().__init__()
                self.obj = obj

            # Ignored methods are not compiled, so the failed attribute is never used.
            @torch.jit.ignore
            def ignored_method(self) -> bool:
                return self.obj.get_a()

            def forward(self, x: int) -> int:
                return x + x

        self.checkModule(ShouldCompile(UnscriptableClass(4)), (4,))

    def test_unresolved_class_attributes(self):
        """
        Test that reading class-level attributes on an instance (including ones
        bound via tuple/list unpacking and via an annotated assignment) fails to
        script with an error explaining that the name is defined as a class
        attribute.
        """

        class UnresolvedAttrClass:
            def __init__(self) -> None:
                pass

            (attr_a, attr_b), [attr_c, attr_d] = ("", ""), ["", ""]
            attr_e: int = 0

        def fn_a():
            u = UnresolvedAttrClass()
            return u.attr_a

        def fn_b():
            u = UnresolvedAttrClass()
            return u.attr_b

        def fn_c():
            u = UnresolvedAttrClass()
            return u.attr_c

        def fn_d():
            u = UnresolvedAttrClass()
            return u.attr_d

        def fn_e():
            u = UnresolvedAttrClass()
            return u.attr_e

        error_message_regex = (
            "object has no attribute or method.*is defined as a class attribute"
        )
        for fn in (fn_a, fn_b, fn_c, fn_d, fn_e):
            with self.assertRaisesRegex(RuntimeError, error_message_regex):
                torch.jit.script(fn)