1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5from typing import Any, Dict, List, NamedTuple, Optional, Tuple # noqa: F401 6 7import torch 8from torch.jit._monkeytype_config import _IS_MONKEYTYPE_INSTALLED 9from torch.testing._internal.common_utils import NoTest 10from torch.testing._internal.jit_utils import JitTestCase, make_global 11 12 13# Make the helper files in test/ importable 14pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 15sys.path.append(pytorch_test_dir) 16 17if not _IS_MONKEYTYPE_INSTALLED: 18 print( 19 "monkeytype is not installed. Skipping tests for Profile-Directed Typing", 20 file=sys.stderr, 21 ) 22 JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811 23 24if __name__ == "__main__": 25 raise RuntimeError( 26 "This test file is not meant to be run directly, use:\n\n" 27 "\tpython test/test_jit.py TESTNAME\n\n" 28 "instead." 29 ) 30 31 32class TestPDT(JitTestCase): 33 """ 34 A suite of tests for profile directed typing in TorchScript. 35 """ 36 37 def test_nn_module(self): 38 class TestPDTModel(torch.nn.Module): 39 def forward(self, x) -> Any: 40 if isinstance(x, int): 41 return x + 1 42 elif isinstance(x, float): 43 return x - 1 44 else: 45 return x 46 47 make_global(TestPDTModel) 48 pdt_model = TestPDTModel() 49 inp: List[Tuple[Any, ...]] = [ 50 (20,), 51 (2.7,), 52 (False,), 53 ] 54 scripted_pdt_model = torch.jit.script( 55 pdt_model, example_inputs={pdt_model: inp} 56 ) 57 self.assertEqual(scripted_pdt_model(50), pdt_model(50)) 58 self.assertEqual(scripted_pdt_model(1.8), pdt_model(1.8)) 59 self.assertTrue(scripted_pdt_model(True), pdt_model(True)) 60 61 def test_nested_nn_module_class(self): 62 class NestedPDTInner(torch.nn.Module): 63 def forward(self, x): 64 if isinstance(x, int): 65 return x * 10 66 return x 67 68 class NestedModulePDTWrapper(torch.nn.Module): 69 def __init__(self, inner): 70 super().__init__() 71 self.inner = inner 72 73 def forward(self, x): 74 return self.inner(x) 75 76 make_global(NestedPDTInner, NestedModulePDTWrapper) 77 inner_pdt_model = NestedPDTInner() 78 wrapped_pdt_model = NestedModulePDTWrapper(inner_pdt_model) 79 inp: List[Tuple[Any, ...]] = [(20,), (False,)] 80 scripted_pdt_model = torch.jit.script( 81 wrapped_pdt_model, example_inputs={wrapped_pdt_model: inp} 82 ) 83 self.assertEqual(scripted_pdt_model(30), wrapped_pdt_model(30)) 84 self.assertEqual(scripted_pdt_model(1.9), wrapped_pdt_model(1.9)) 85 self.assertTrue(scripted_pdt_model(True), wrapped_pdt_model(True)) 86 87 def test_nested_nn_module_class_with_args(self): 88 class NestedModulePDTInner(torch.nn.Module): 89 def forward(self, x, y): 90 if isinstance(x, int): 91 return x * 10 + y 92 return x 93 94 class NestedModulePDTOuter(torch.nn.Module): 95 def __init__(self, inner): 96 super().__init__() 97 self.inner = inner 98 99 def forward(self, x): 100 return self.inner(x, 20) 101 102 make_global(NestedModulePDTInner, NestedModulePDTOuter) 103 inner_pdt_model = NestedModulePDTInner() 104 outer_pdt_model = NestedModulePDTOuter(inner_pdt_model) 105 inner_input: List[Tuple[Any, ...]] = [ 106 (10, 10), 107 (1.9, 20), 108 ] 109 outer_input: List[Tuple[Any, ...]] = [(20,), (False,)] 110 scripted_pdt_model = torch.jit.script( 111 outer_pdt_model, 112 example_inputs={ 113 inner_pdt_model: inner_input, 114 outer_pdt_model: outer_input, 115 }, 116 ) 117 self.assertEqual(scripted_pdt_model(30), outer_pdt_model(30)) 118 self.assertEqual(scripted_pdt_model(1.9), outer_pdt_model(1.9)) 119 self.assertTrue(scripted_pdt_model(True), outer_pdt_model(True)) 120 121 def test_nested_function_in_forward(self): 122 class NestedFunctionInForward(torch.nn.Module): 123 def forward(self, x): 124 return self.fun(x) + 10 125 126 def fun(self, x): 127 if isinstance(x, bool): 128 return 0 129 elif isinstance(x, int): 130 return x + 1 131 return 0 132 133 make_global(NestedFunctionInForward) 134 pdt_model = NestedFunctionInForward() 135 inp: List[Tuple[Any, ...]] = [(-1,), (False,)] 136 scripted_pdt_model = torch.jit.script( 137 pdt_model, example_inputs={pdt_model: inp} 138 ) 139 self.assertEqual(scripted_pdt_model(30), pdt_model(30)) 140 self.assertEqual(scripted_pdt_model(True), pdt_model(True)) 141 142 def test_nn_module_with_export_function(self): 143 class TestModelWithExport(torch.nn.Module): 144 @torch.jit.export 145 def fn(self, x, y) -> Any: 146 assert not (isinstance(x, bool) and isinstance(y, bool)) 147 if isinstance(x, int) and isinstance(y, int): 148 return x + y 149 elif isinstance(x, float) and isinstance(y, float): 150 return x - y 151 else: 152 return -1 153 154 make_global(TestModelWithExport) 155 pdt_model = TestModelWithExport() 156 inp: List[Tuple[Any, ...]] = [ 157 ( 158 20, 159 10, 160 ), 161 ( 162 2.7, 163 8.9, 164 ), 165 ] 166 scripted_pdt_model = torch.jit.script( 167 pdt_model, example_inputs={pdt_model.fn: inp} 168 ) 169 self.assertEqual(scripted_pdt_model.fn(10, 90), pdt_model.fn(10, 90)) 170 self.assertEqual(scripted_pdt_model.fn(1.8, 2.2), pdt_model.fn(1.8, 2.2)) 171 self.assertTrue( 172 scripted_pdt_model.fn(torch.ones(1), 2), pdt_model.fn(torch.ones(1), 2) 173 ) 174 175 def test_class_methods(self): 176 class PDTModel: 177 def test_sum(self, a): 178 return sum(a) 179 180 make_global(PDTModel) 181 pdt_model = PDTModel() 182 inp: List[Tuple[Any, ...]] = [ 183 ( 184 [ 185 10, 186 20, 187 ], 188 ), 189 ] 190 scripted_pdt_model = torch.jit.script( 191 PDTModel, example_inputs={pdt_model.test_sum: inp} 192 ) 193 script_model = scripted_pdt_model() 194 self.assertEqual( 195 script_model.test_sum( 196 [ 197 10, 198 20, 199 30, 200 ], 201 ), 202 pdt_model.test_sum( 203 [ 204 10, 205 20, 206 30, 207 ], 208 ), 209 ) 210 211 def test_class_with_multiple_methods(self): 212 class PDTModelWithManyMethods: 213 def test_list_to_dict(self, a): 214 new_dictionary: Dict[float, bool] = {} 215 for element in a: 216 new_dictionary[element] = True 217 return new_dictionary 218 219 def test_substring(self, a, b): 220 return b in a 221 222 make_global(PDTModelWithManyMethods) 223 pdt_model = PDTModelWithManyMethods() 224 list_inp: List[Tuple[Any, ...]] = [ 225 ( 226 [ 227 1.2, 228 2.3, 229 ], 230 ), 231 ] 232 str_inp: List[Tuple[Any, ...]] = [ 233 ( 234 "abc", 235 "b", 236 ), 237 ] 238 scripted_pdt_model = torch.jit.script( 239 PDTModelWithManyMethods, 240 example_inputs={ 241 pdt_model.test_list_to_dict: list_inp, 242 pdt_model.test_substring: str_inp, 243 }, 244 ) 245 script_model = scripted_pdt_model() 246 self.assertEqual( 247 script_model.test_list_to_dict( 248 [ 249 1.1, 250 2.2, 251 3.3, 252 ], 253 ), 254 pdt_model.test_list_to_dict( 255 [ 256 1.1, 257 2.2, 258 3.3, 259 ], 260 ), 261 ) 262 self.assertEqual( 263 script_model.test_substring( 264 "helloworld", 265 "world", 266 ), 267 pdt_model.test_substring( 268 "helloworld", 269 "world", 270 ), 271 ) 272 self.assertEqual( 273 script_model.test_substring( 274 "helloworld", 275 "def", 276 ), 277 pdt_model.test_substring( 278 "helloworld", 279 "def", 280 ), 281 ) 282 283 def test_multiple_class_with_same_method(self): 284 class PDTModelOne: 285 def test_find(self, a, b): 286 return b in a.keys() 287 288 class PDTModelTwo: 289 def test_find(self, a, b): 290 return b in a 291 292 make_global(PDTModelOne, PDTModelTwo) 293 pdt_model_one = PDTModelOne() 294 pdt_model_two = PDTModelTwo() 295 dict_inp: List[Tuple[Any, ...]] = [ 296 ( 297 { 298 1.2: True, 299 2.3: False, 300 }, 301 1.2, 302 ), 303 ] 304 list_inp: List[Tuple[Any, ...]] = [ 305 ( 306 [ 307 "abc", 308 "b", 309 ], 310 "c", 311 ), 312 ] 313 scripted_pdt_model_one = torch.jit.script( 314 PDTModelOne, example_inputs={pdt_model_one.test_find: dict_inp} 315 ) 316 scripted_pdt_model_two = torch.jit.script( 317 PDTModelTwo, example_inputs={pdt_model_two.test_find: list_inp} 318 ) 319 320 script_model_one, script_model_two = ( 321 scripted_pdt_model_one(), 322 scripted_pdt_model_two(), 323 ) 324 self.assertEqual( 325 script_model_one.test_find( 326 { 327 1.1: True, 328 2.2: True, 329 3.3: False, 330 }, 331 4.4, 332 ), 333 pdt_model_one.test_find( 334 { 335 1.1: True, 336 2.2: True, 337 3.3: False, 338 }, 339 4.4, 340 ), 341 ) 342 self.assertEqual( 343 script_model_two.test_find( 344 [ 345 "hello", 346 "world", 347 ], 348 "world", 349 ), 350 pdt_model_two.test_find( 351 [ 352 "hello", 353 "world", 354 ], 355 "world", 356 ), 357 ) 358 359 def test_pdt(self): 360 def test_sum(a, b): 361 return a + b 362 363 make_global(test_sum) 364 scripted_fn_add = torch.jit.script(test_sum, example_inputs=[(3, 4)]) 365 self.assertEqual(scripted_fn_add(10, 2), test_sum(10, 2)) 366 367 def test_sub(a, b): 368 return a - b 369 370 make_global(test_sub) 371 scripted_fn_sub = torch.jit.script(test_sub, example_inputs=[(3.9, 4.10)]) 372 self.assertEqual(scripted_fn_sub(6.5, 2.9), test_sub(6.5, 2.9)) 373 374 def test_mul(a, b): 375 return a * b 376 377 make_global(test_mul) 378 scripted_fn_mul = torch.jit.script(test_mul, example_inputs=[(-10, 9)]) 379 self.assertEqual(scripted_fn_mul(-1, 3), test_mul(-1, 3)) 380 381 def test_args_complex(real, img): 382 return torch.complex(real, img) 383 384 make_global(test_args_complex) 385 scripted_fn_complex = torch.jit.script( 386 test_args_complex, example_inputs=[(torch.rand(3, 4), torch.rand(3, 4))] 387 ) 388 arg1, arg2 = torch.rand(3, 4), torch.rand(3, 4) 389 self.assertEqual(scripted_fn_complex(arg1, arg2), test_args_complex(arg1, arg2)) 390 391 def test_bool(a): 392 if a: 393 return -1 394 else: 395 return 0 396 397 make_global(test_bool) 398 scripted_fn_bool = torch.jit.script(test_bool, example_inputs=[(True,)]) 399 self.assertEqual(scripted_fn_bool(True), test_bool(True)) 400 401 def test_str(a): 402 if a == "": 403 return False 404 else: 405 return True 406 407 make_global(test_str) 408 scripted_fn_str = torch.jit.script(test_str, example_inputs=[("",)]) 409 self.assertEqual(scripted_fn_str("abc"), test_str("abc")) 410 411 def test_pdt_list_and_tuple(self): 412 def test_list_and_tuple(a): 413 return sum(a) 414 415 make_global(test_list_and_tuple) 416 417 scripted_fn_float_list_input = torch.jit.script( 418 test_list_and_tuple, example_inputs=[([4.9, 8.9],)] 419 ) 420 self.assertEqual( 421 scripted_fn_float_list_input([11.9, 7.6]), test_list_and_tuple([11.9, 7.6]) 422 ) 423 424 scripted_fn_bool_list_input = torch.jit.script( 425 test_list_and_tuple, example_inputs=[([True, False, True],)] 426 ) 427 self.assertEqual( 428 scripted_fn_bool_list_input([True, True, True]), 429 test_list_and_tuple([True, True, True]), 430 ) 431 432 scripted_fn_int_list_input = torch.jit.script( 433 test_list_and_tuple, example_inputs=[([3, 4, 5],)] 434 ) 435 self.assertEqual( 436 scripted_fn_int_list_input([1, 2, 3]), test_list_and_tuple([1, 2, 3]) 437 ) 438 439 scripted_fn_float_tuple_input = torch.jit.script( 440 test_list_and_tuple, example_inputs=[((4.9, 8.9),)] 441 ) 442 self.assertEqual( 443 scripted_fn_float_tuple_input((11.9, 7.6)), test_list_and_tuple((11.9, 7.6)) 444 ) 445 446 scripted_fn_bool_tuple_input = torch.jit.script( 447 test_list_and_tuple, example_inputs=[((True, False, True),)] 448 ) 449 self.assertEqual( 450 scripted_fn_bool_tuple_input((True, True, True)), 451 test_list_and_tuple((True, True, True)), 452 ) 453 454 scripted_fn_int_tuple_input = torch.jit.script( 455 test_list_and_tuple, example_inputs=[((3, 4, 5),)] 456 ) 457 self.assertEqual( 458 scripted_fn_int_tuple_input((1, 2, 3)), test_list_and_tuple((1, 2, 3)) 459 ) 460 461 def test_nested_list_and_tuple(self): 462 def test_nested_list(inp): 463 return [sum(v) for v in inp] 464 465 def test_nested_tuple(inp): 466 ans = 0.0 467 for tup in inp: 468 for val in tup: 469 if val > 0: 470 ans *= val 471 return ans 472 473 make_global(test_nested_list, test_nested_tuple) 474 475 list_inp = [ 476 [ 477 1, 478 2, 479 3, 480 ], 481 [ 482 5, 483 6, 484 7, 485 ], 486 ] 487 scripted_fn = torch.jit.script( 488 test_nested_list, 489 example_inputs=[ 490 (list_inp,), 491 ], 492 ) 493 inp = [ 494 [ 495 0, 496 4, 497 7, 498 ], 499 [ 500 8, 501 11, 502 ], 503 [ 504 6, 505 -1, 506 -20, 507 ], 508 ] 509 self.assertEqual( 510 scripted_fn( 511 inp, 512 ), 513 test_nested_list( 514 inp, 515 ), 516 ) 517 518 list_inp = ( 519 [ 520 1, 521 2, 522 3, 523 ], 524 [ 525 5, 526 6, 527 7, 528 ], 529 ) 530 scripted_fn = torch.jit.script( 531 test_nested_list, 532 example_inputs=[ 533 (list_inp,), 534 ], 535 ) 536 inp = ( 537 [ 538 0, 539 4, 540 7, 541 ], 542 [ 543 8, 544 11, 545 ], 546 [ 547 6, 548 -1, 549 -20, 550 ], 551 ) 552 self.assertEqual( 553 scripted_fn( 554 inp, 555 ), 556 test_nested_list( 557 inp, 558 ), 559 ) 560 561 tup_inp = [ 562 ( 563 1.0, 564 2.6, 565 3.7, 566 ), 567 ( 568 5.7, 569 6.1, 570 1.7, 571 ), 572 ] 573 scripted_fn = torch.jit.script( 574 test_nested_tuple, 575 example_inputs=[ 576 (tup_inp,), 577 ], 578 ) 579 inp = [ 580 ( 581 1.0, 582 4.1, 583 7.4, 584 ), 585 ( 586 4.8, 587 1.1, 588 -1.2, 589 ), 590 ( 591 6.3, 592 -1.3, 593 -2.0, 594 ), 595 ] 596 self.assertEqual( 597 scripted_fn( 598 inp, 599 ), 600 test_nested_tuple( 601 inp, 602 ), 603 ) 604 605 tup_inp = ( 606 ( 607 True, 608 False, 609 True, 610 ), 611 ( 612 False, 613 False, 614 False, 615 ), 616 ) 617 scripted_fn = torch.jit.script( 618 test_nested_tuple, 619 example_inputs=[ 620 (tup_inp,), 621 ], 622 ) 623 inp = ( 624 ( 625 True, 626 True, 627 True, 628 ), 629 ( 630 False, 631 False, 632 True, 633 ), 634 ) 635 self.assertEqual( 636 scripted_fn( 637 inp, 638 ), 639 test_nested_tuple( 640 inp, 641 ), 642 ) 643 644 def test_pdt_dict(self): 645 def test_dict(a): 646 return a["foo"] 647 648 def test_dict_int_list(a): 649 return a[1] 650 651 make_global(test_dict, test_dict_int_list) 652 653 str_bool_inp = {"foo": True, "bar": False} 654 scripted_fn = torch.jit.script(test_dict, example_inputs=[(str_bool_inp,)]) 655 self.assertEqual( 656 scripted_fn( 657 {"foo": False, "bar": True}, 658 ), 659 test_dict( 660 {"foo": False, "bar": True}, 661 ), 662 ) 663 664 str_list_inp = {0: [True, False], 1: [False, True]} 665 scripted_fn = torch.jit.script( 666 test_dict_int_list, example_inputs=[(str_list_inp,)] 667 ) 668 self.assertEqual( 669 scripted_fn( 670 {0: [False, False], 1: [True, True]}, 671 ), 672 test_dict_int_list( 673 {0: [False, False], 1: [True, True]}, 674 ), 675 ) 676 677 def test_any(self): 678 def test_multiple_types(a): 679 assert not isinstance(a, bool) 680 return a 681 682 def test_multiple_type_refinement(a): 683 if isinstance(a, bool): 684 return 1 685 elif isinstance(a, int): 686 return 1 + a 687 elif isinstance(a, float): 688 return 1 + int(a) 689 else: 690 return -1 691 692 make_global(test_multiple_types, test_multiple_type_refinement) 693 694 scripted_fn = torch.jit.script( 695 test_multiple_types, example_inputs=[(1,), ("abc",), (8.9,), ([3, 4, 5],)] 696 ) 697 self.assertEqual(scripted_fn(10), test_multiple_types(10)) 698 self.assertEqual(scripted_fn("def"), test_multiple_types("def")) 699 self.assertEqual(scripted_fn(7.89999), test_multiple_types(7.89999)) 700 self.assertEqual(scripted_fn([10, 11, 14]), test_multiple_types([10, 11, 14])) 701 702 scripted_fn = torch.jit.script( 703 test_multiple_type_refinement, 704 example_inputs=[ 705 (1,), 706 ("abc",), 707 (8.9,), 708 ([3, 4, 5],), 709 (True,), 710 ({"a": True},), 711 ], 712 ) 713 self.assertEqual(scripted_fn(10), test_multiple_type_refinement(10)) 714 self.assertEqual(scripted_fn("def"), test_multiple_type_refinement("def")) 715 self.assertEqual(scripted_fn(7.89999), test_multiple_type_refinement(7.89999)) 716 self.assertEqual( 717 scripted_fn([10, 11, 14]), test_multiple_type_refinement([10, 11, 14]) 718 ) 719 self.assertEqual(scripted_fn(False), test_multiple_type_refinement(False)) 720 self.assertEqual( 721 scripted_fn({"abc": True, "def": False}), 722 test_multiple_type_refinement({"abc": True, "def": False}), 723 ) 724 725 def test_class_as_profiled_types(self): 726 class UserDefinedClass: 727 def fn(self, b) -> Any: 728 assert b is not None 729 if isinstance(b, int): 730 return b if b > 0 else -1 731 elif isinstance(b, float): 732 return b if b > 0.0 else -1.0 733 return 0 734 735 def test_model(a, m): 736 assert not isinstance(a, bool) 737 return m.fn(a) 738 739 make_global(UserDefinedClass, test_model) 740 741 user_class = UserDefinedClass() 742 scripted_fn = torch.jit.script( 743 test_model, 744 example_inputs=[ 745 ( 746 10, 747 user_class, 748 ), 749 ( 750 10.9, 751 user_class, 752 ), 753 ], 754 ) 755 self.assertEqual( 756 scripted_fn( 757 100, 758 user_class, 759 ), 760 test_model(100, user_class), 761 ) 762 self.assertEqual( 763 scripted_fn( 764 1.9, 765 user_class, 766 ), 767 test_model(1.9, user_class), 768 ) 769 770 def test_class_with_args_as_profiled_types(self): 771 class ClassWithArgs: 772 def __init__(self, a: bool): 773 self.a = a 774 775 def fn(self, b): 776 if self.a: 777 return b 778 else: 779 return -1 780 781 def test_model_with_args(a, m): 782 assert not isinstance(a, bool) 783 return m.fn(a) 784 785 make_global(ClassWithArgs, test_model_with_args) 786 787 user_class = ClassWithArgs(False) 788 scripted_fn = torch.jit.script( 789 test_model_with_args, 790 example_inputs=[ 791 ( 792 10, 793 user_class, 794 ), 795 ( 796 10.9, 797 user_class, 798 ), 799 ], 800 ) 801 self.assertEqual( 802 scripted_fn( 803 100, 804 ClassWithArgs(True), 805 ), 806 test_model_with_args(100, ClassWithArgs(True)), 807 ) 808 809 def test_nn_parameter_as_arg(self): 810 class TestNNParameter(torch.nn.Module): 811 def __init__(self) -> None: 812 super().__init__() 813 self.inp = torch.nn.Parameter(torch.ones(2, 3)) 814 815 def add_nn_parameter_with_int(self, x, y): 816 return torch.add(x, y) 817 818 def forward(self, y): 819 return self.add_nn_parameter_with_int(self.inp, y) 820 821 make_global(TestNNParameter) 822 pdt_model = TestNNParameter() 823 scripted_fn = torch.jit.script( 824 pdt_model, 825 example_inputs={ 826 pdt_model: [ 827 (10,), 828 ], 829 }, 830 ) 831 self.assertEqual(scripted_fn(20), pdt_model(20)) 832 833 def test_fx_tracing_with_typing(self): 834 class FXModelOutput(NamedTuple): 835 result: List[int] 836 837 class FXModel(torch.nn.Module): 838 def forward(self, a) -> FXModelOutput: 839 result = FXModelOutput(result=a) 840 return result 841 842 make_global(FXModel, FXModelOutput) 843 pdt_model = FXModel() 844 scripted_fn = torch.jit.script( 845 pdt_model, 846 example_inputs={ 847 pdt_model: [ 848 ( 849 [ 850 10, 851 20, 852 ], 853 ), 854 ], 855 }, 856 ) 857 self.assertEqual(scripted_fn([20]), pdt_model([20])) 858 859 def test_nonetype_as_optional_of_type(self): 860 def test_none(a) -> Any: 861 if a is None: 862 return 0 863 else: 864 return a + torch.ones(1) 865 866 make_global(test_none) 867 868 scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10.6,)]) 869 self.assertEqual( 870 scripted_fn( 871 30.9, 872 ), 873 test_none( 874 30.9, 875 ), 876 ) 877 878 scripted_fn = torch.jit.script(test_none, example_inputs=[(None,), (10,)]) 879 self.assertEqual( 880 scripted_fn( 881 2, 882 ), 883 test_none( 884 2, 885 ), 886 ) 887 888 scripted_fn = torch.jit.script( 889 test_none, example_inputs=[(None,), (torch.Tensor(1),)] 890 ) 891 self.assertEqual( 892 scripted_fn( 893 torch.ones(1), 894 ), 895 test_none( 896 torch.ones(1), 897 ), 898 ) 899