1# Owner(s): ["oncall: jit"] 2 3import inspect 4import os 5import sys 6import types 7import unittest 8from collections import defaultdict, OrderedDict 9from textwrap import dedent 10from typing import Any, Dict, List, NamedTuple, Optional, Tuple 11 12import torch 13import torch.nn as nn 14from torch import Tensor 15from torch.testing import FileCheck 16 17 18# Make the helper files in test/ importable 19pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 20sys.path.append(pytorch_test_dir) 21from torch.testing._internal.common_utils import skipIfTorchDynamo, TEST_CUDA 22from torch.testing._internal.jit_utils import JitTestCase, make_global 23 24 25if __name__ == "__main__": 26 raise RuntimeError( 27 "This test file is not meant to be run directly, use:\n\n" 28 "\tpython test/test_jit.py TESTNAME\n\n" 29 "instead." 30 ) 31 32 33class TestList(JitTestCase): 34 def test_list_bool_conversion(self): 35 def if_predicate(l: List[int]): 36 if l: 37 s = 0 38 for n in l: 39 s += n 40 41 return s 42 else: 43 return -1 44 45 self.checkScript(if_predicate, ([1, 2, 3],)) 46 self.checkScript(if_predicate, ([],)) 47 48 def while_predicate(l: List[int]): 49 s = 0 50 51 while l: 52 s += l.pop() 53 54 self.checkScript(while_predicate, ([1, 2, 3],)) 55 self.checkScript(while_predicate, ([],)) 56 57 def ternary_predicate(l: List[int]): 58 return "non-empty" if l else "empty" 59 60 self.checkScript(ternary_predicate, ([1, 2, 3],)) 61 self.checkScript(ternary_predicate, ([],)) 62 63 def test_in_check(self): 64 def int_in(x: List[int]) -> bool: 65 return 2 in x 66 67 self.checkScript(int_in, ([1, 2, 3],)) 68 self.checkScript(int_in, ([1, 3, 3],)) 69 70 def float_in(x: List[float]) -> bool: 71 return 2.0 in x 72 73 self.checkScript(float_in, ([1.0, 2.0, 3.0],)) 74 self.checkScript(float_in, ([1.0, 3.0, 3.0],)) 75 76 def str_in(x: List[str]) -> bool: 77 return "hi" in x 78 79 self.checkScript(str_in, (["not", "here"],)) 80 self.checkScript(str_in, (["hi", 
"bye"],)) 81 self.checkScript(str_in, ([],)) 82 83 def test_list_literal(self): 84 def reassign(): 85 x = [1] 86 if 1 == 1: 87 x = [2, 3] 88 return 89 90 self.checkScript(reassign, (), optimize=False) 91 92 def reassign_arity_change(): 93 x = [1] 94 if 1 == 1: 95 x = [1, 2, 3] 96 return 97 98 self.checkScript(reassign_arity_change, (), optimize=False) 99 100 def reassign_from_empty_literal(): 101 x = [] 102 if 1 == 1: 103 x = [1, 2, 3] 104 return 105 106 with self.assertRaisesRegexWithHighlight( 107 RuntimeError, r"previously had type List\[Tensor\]", "x" 108 ): 109 self.checkScript(reassign_from_empty_literal, (), optimize=False) 110 111 def reassign_from_empty_builtin(): 112 x = torch.jit.annotate(List[int], []) 113 if 1 == 1: 114 x = [1, 2, 3] 115 y = torch.jit.annotate(List[float], []) 116 if 1 == 1: 117 y = [1.0, 2.0, 3.0] 118 z = [] 119 if 1 == 1: 120 z = [torch.randn([1])] 121 return 122 123 self.checkScript(reassign_from_empty_builtin, (), optimize=False) 124 125 def reassign_bad_type(): 126 x = [1] 127 if 1 == 1: 128 x = [1.0] 129 return 130 131 with self.assertRaisesRegexWithHighlight( 132 RuntimeError, "previously had type", "x" 133 ): 134 self.checkScript(reassign_bad_type, (), optimize=False) 135 136 def reassign_nested(): 137 x = torch.jit.annotate(List[int], []) 138 if 1 == 1: 139 x = [1, 2, 3] 140 if 1 == 1: 141 x = [1.0] 142 return 143 144 with self.assertRaisesRegexWithHighlight( 145 RuntimeError, "previously had type", "x" 146 ): 147 self.checkScript(reassign_nested, (), optimize=False) 148 149 def test_list_variance(self): 150 """ 151 `List[T1]` is not a subtype of `List[T2]`, even if `T1` is a 152 subtype of `T2`. 
However, if we have a temporary list object 153 (that is, a list comprehension or a list literal) on the rhs of 154 an assignment statement, we want to ignore the inferred type of 155 the rhs if we can prove that: 1) both the lhs and the rhs are 156 lists, and 2) the inner type of the lhs list is a subtype of the 157 inner type of the rhs list. 158 159 # This should pass 160 x: List[Optional[int]] = [None, None, None] 161 162 # This should fail 163 y: List[None] = [None, None, None] 164 x: List[Optional[int]] = y 165 """ 166 167 def test_listliteral_is_typed_from_annotation(): 168 x: List[Optional[int]] = [None, None, None] 169 return x 170 171 self.checkScript(test_listliteral_is_typed_from_annotation, ()) 172 173 def test_listcomprehension_is_typed_from_annotation(): 174 x: List[Optional[int]] = [None for _ in range(3)] 175 return x 176 177 self.checkScript(test_listcomprehension_is_typed_from_annotation, ()) 178 179 def test_lists_with_different_internal_types_are_invariant(self): 180 x: List[int] = [1, 2, 3] 181 y: List[Optional[int]] = x 182 return x 183 184 with self.assertRaisesRegex( 185 RuntimeError, 186 "Variable 'y' is " 187 "annotated with type " 188 r"List\[Optional\[int\]\] but is " 189 "being assigned to a value of type " 190 r"List\[int\]", 191 ): 192 torch.jit.script(test_lists_with_different_internal_types_are_invariant) 193 194 def test_lists_with_different_internal_types_are_invariant_recursive(self): 195 x: List[List[int]] = [[1, 2], [3]] 196 y: List[List[Optional[int]]] = x 197 return x 198 199 with self.assertRaisesRegex( 200 RuntimeError, 201 "Variable 'y' is " 202 "annotated with type " 203 r"List\[List\[Optional\[int\]\]\] " 204 "but is being assigned to a value " 205 r"of type List\[List\[int\]\]", 206 ): 207 torch.jit.script( 208 test_lists_with_different_internal_types_are_invariant_recursive 209 ) 210 211 def test_del(self): 212 def inputs(): 213 return [1, 2, 3, 4] 214 215 def fn(x: List[int]) -> List[int]: 216 del x[1] 217 return x 
218 219 python_out = fn(inputs()) 220 # checkScript reuses the same object, but here it's being mutated so do 221 # it manually 222 cu = torch.jit.CompilationUnit() 223 cu.define(dedent(inspect.getsource(fn))) 224 self.assertEqual(cu.fn(inputs()), python_out) 225 self.assertEqual(torch.jit.script(fn)(inputs()), python_out) 226 227 @torch.jit.script 228 def fn2(x: List[int]) -> List[int]: 229 del x[100] 230 return x 231 232 with self.assertRaisesRegexWithHighlight( 233 RuntimeError, "out of range", "x[100]" 234 ): 235 fn2([]) 236 237 with self.assertRaisesRegexWithHighlight( 238 RuntimeError, "deletion at a single index", "x[1:3]" 239 ): 240 241 @torch.jit.script 242 def fn(x: List[int]) -> List[int]: 243 del x[1:3] 244 return x 245 246 def test_list_keyword(self): 247 def foo(): 248 return ( 249 list([1, 2, 3]), # noqa: C410 250 list(("a", "b")), # noqa: C410 251 list(range(5)), 252 list("abcdefg"), 253 ) 254 255 self.checkScript(foo, ()) 256 257 def foo2(): 258 x: List[int] = list() # noqa: C408 259 x.append(1) 260 return (x,) 261 262 self.checkScript(foo2, ()) 263 264 def foo3(): 265 return list(list("abc")) # noqa: C414 266 267 self.checkScript(foo3, ()) 268 FileCheck().check_count("aten::list", 2, exactly=True).run( 269 torch.jit.script(foo3).graph 270 ) 271 272 def test_dict_keyword_with_kwargs(self): 273 def fn(): 274 return dict(foo=1, bar=2, baz=3) 275 276 self.checkScript(fn, ()) 277 278 def test_dict_keyword_with_kwargs_using_container_values(self): 279 def fn(): 280 return dict(foo=[1, 2, 3], bar=[4, 5, 6], baz=[7, 8, 9]) 281 282 self.checkScript(fn, ()) 283 284 def test_dict_keyword_with_iterable(self): 285 def fn(): 286 return dict([("foo", 1), ("bar", 2), ("baz", 3)]) # noqa: C406 287 288 self.checkScript(fn, ()) 289 290 def test_dict_keyword_with_empty_iterable(self): 291 def fn(): 292 return dict([]) # noqa: C406 293 294 self.checkScript(fn, ()) 295 296 def test_dict_keyword_with_internal_aggregate_function(self): 297 def fn(): 298 return 
dict(zip(["foo", "baz", "bar"], [1, 2, 3])) 299 300 self.checkScript(fn, ()) 301 302 def test_dict_keyword_with_mapping(self): 303 def fn(): 304 return {"foo": 1, "bar": 2, "baz": 3} 305 306 self.checkScript(fn, ()) 307 308 def test_dict_keyword_with_mapping_and_kwargs(self): 309 def fn(): 310 return dict({"foo": 1, "bar": 2}, baz=3) 311 312 self.checkScript(fn, ()) 313 314 def test_dict_keyword_with_dict_comprehension(self): 315 def fn(): 316 return {i: chr(i + 65) for i in range(4)} 317 318 self.checkScript(fn, ()) 319 320 def test_dict_keyword_with_dict_comprehension_and_kwargs(self): 321 def fn(): 322 return dict({chr(65 + i): i for i in range(4)}, foo=2) 323 324 self.checkScript(fn, ()) 325 326 def test_dict_keyword_with_empty_dict_comprehension(self): 327 def fn(): 328 return {} 329 330 self.checkScript(fn, ()) 331 332 def test_dict_keyword_is_correctly_typed(self): 333 def fn(): 334 x: Dict[str, int] = dict() # noqa: C408 335 x["foo"] = 1 336 return x 337 338 self.checkScript(fn, ()) 339 340 def test_dict_keyword_with_mismatched_annotations(self): 341 err_msg = ( 342 r"Dict type annotation `Dict\[int, str\]` did not " 343 "match the type of an actual key type `str`" 344 ) 345 with self.assertRaisesRegex(RuntimeError, err_msg): 346 347 @torch.jit.script 348 def fn(): 349 x: Dict[int, str] = dict( # noqa: C406 350 [("foo", 1), ("bar", 2), ("baz", 3)] 351 ) 352 return x 353 354 def test_dict_keyword_with_nested_call(self): 355 def fn(): 356 return dict(dict(foo=1, bar=2, baz=3)) 357 358 self.checkScript(fn, ()) 359 360 def test_dict_keyword_with_previously_declared_variable(self): 361 def fn(): 362 d = {"foo": 1, "bar": 2} 363 return dict(d) 364 365 self.checkScript(fn, ()) 366 367 def test_dict_keyword_with_previously_declared_variable_and_kwargs(self): 368 def fn(): 369 d = {"foo": 1, "bar": 2} 370 return dict(d, baz=3) 371 372 self.checkScript(fn, ()) 373 374 def test_min_bool_list(self): 375 def jit_min_list(a: List[bool], b: List[bool]) -> List[bool]: 376 
return min(a, b) 377 378 self.checkScript(jit_min_list, ([True, False], [False, True])) 379 380 def test_min_max_list(self): 381 def jit_min_list(a: List[int], b: List[int]) -> List[int]: 382 return min(a, b) 383 384 def jit_min_list_float(a: List[float], b: List[float]) -> List[float]: 385 return min(a, b) 386 387 def jit_min_list_bool(a: List[bool], b: List[bool]) -> List[bool]: 388 return min(a, b) 389 390 def run_tests(func, a, b): 391 for t in zip(a, b): 392 self.checkScript(func, t) 393 394 args_left_int = [[1, 8, 8], [2, 1, 1], [], [2], [1], [1, 2, 3]] 395 args_right_int = [[2, 1, 1], [1, 8, 8], [], [1], [], [1, 2]] 396 run_tests(jit_min_list, args_left_int, args_right_int) 397 398 args_left_float = [ 399 [1.0, 8.0, 8.0], 400 [2.0, 1.0, 1.0], 401 [], 402 [2.0], 403 [1.0], 404 [1.0, 2.0, 3.0], 405 ] 406 args_right_float = [[2.0, 1.0, 1.0], [1.0, 8.0, 8.0], [], [1.0], [], [1.0, 2.0]] 407 run_tests(jit_min_list_float, args_left_float, args_right_float) 408 409 args_left_bool = [ 410 [], 411 [], 412 [], 413 [False], 414 [True], 415 [False, True], 416 [True, True], 417 [False, False, False], 418 [False, False, True], 419 ] 420 args_right_bool = [ 421 [], 422 [False], 423 [True], 424 [True], 425 [False], 426 [True, True], 427 [False, True], 428 [False, False, True], 429 [False, False, False], 430 ] 431 run_tests(jit_min_list_bool, args_left_bool, args_right_bool) 432 433 def jit_max_list(a: List[int], b: List[int]) -> List[int]: 434 return max(a, b) 435 436 def jit_max_list_float(a: List[float], b: List[float]) -> List[float]: 437 return max(a, b) 438 439 def jit_max_list_bool(a: List[bool], b: List[bool]) -> List[bool]: 440 return max(a, b) 441 442 args_left_int = [[1, 8, 8], [8, 1, 1], [], [1], [], [1, 2]] 443 args_right_int = [[8, 1, 1], [1, 8, 8], [], [2], [1], [1, 2, 3]] 444 run_tests(jit_max_list, args_left_int, args_right_int) 445 446 args_left_float = [[1.0, 8.0, 8.0], [8.0, 1.0, 1.0], [], [1.0], [], [1.0, 2.0]] 447 args_right_float = [ 448 [8.0, 1.0, 
1.0], 449 [1.0, 8.0, 8.0], 450 [], 451 [2.0], 452 [1.0], 453 [1.0, 2.0, 3.0], 454 ] 455 run_tests(jit_max_list_float, args_left_float, args_right_float) 456 457 run_tests(jit_max_list_bool, args_left_bool, args_right_bool) 458 459 def test_list_gather(self): 460 def index(): 461 a = [1, 2, 3] 462 return a[1] 463 464 self.checkScript(index, ()) 465 466 def negative_index(): 467 a = [1, 2, 3] 468 return a[-1] 469 470 self.checkScript(negative_index, ()) 471 472 def bad_index(): 473 a = [1, 2, 3] 474 return a[4] 475 476 self.checkScriptRaisesRegex(bad_index, (), Exception, "list index out of range") 477 478 def bad_negative_index(): 479 a = [1, 2, 3] 480 return a[-5] 481 482 self.checkScriptRaisesRegex( 483 bad_negative_index, (), Exception, "list index out of range" 484 ) 485 486 def test_list_len(self): 487 def func(): 488 a = [1, 2, 3] 489 return len(a) == 3 490 491 self.checkScript(func, ()) 492 493 def func2(): 494 a = [] 495 return len(a) == 0 496 497 self.checkScript(func2, ()) 498 499 @skipIfTorchDynamo( 500 "TorchDynamo fails to raise on this checkScriptRaisesRegex, because we trace it properly now" 501 ) 502 def test_list_ops(self): 503 def test_equality(): 504 a = [1, 2, 3] 505 b = [1, 2, 3] 506 return a == b 507 508 self.checkScript(test_equality, (), optimize=True) 509 510 def test_equality_str(): 511 a = ["foo", "bar"] 512 b = ["foo", "bar"] 513 return a == b 514 515 self.checkScript(test_equality_str, (), optimize=True) 516 517 def test_inequality(): 518 a = [1, 2, 3] 519 b = [1, 2, 3] 520 return a != b 521 522 self.checkScript(test_inequality, (), optimize=True) 523 524 def test_inequality_str(): 525 a = ["foo", "bar"] 526 b = ["foo", "bar", "food"] 527 return a != b 528 529 self.checkScript(test_inequality_str, (), optimize=True) 530 531 def test_non_equality(): 532 a = [1, 2, 3] 533 b = [3] 534 return a == b 535 536 self.checkScript(test_non_equality, (), optimize=True) 537 538 def test_non_inequality(): 539 a = [1, 2, 3] 540 b = [3] 541 return a != 
b 542 543 self.checkScript(test_non_equality, (), optimize=True) 544 545 def test_list_equality_as_cond(): 546 a = [1, 2, 3] 547 b = [3] 548 if a == b: 549 c = 1 550 else: 551 c = 2 552 return c 553 554 self.checkScript(test_list_equality_as_cond, (), optimize=True) 555 556 def test_list_add(): 557 a = [1, 2, 3] 558 b = [2] 559 c = a + b 560 return c == [1, 2, 3, 2] 561 562 self.checkScript(test_list_add, (), optimize=True) 563 564 def test_list_add_empty(): 565 a = [1, 2, 3] 566 b = torch.jit.annotate(List[int], []) 567 c = a + b 568 return c == [1, 2, 3] 569 570 self.checkScript(test_list_add_empty, (), optimize=True) 571 572 def test_tensor_list_equality(): 573 t1 = torch.ones([1, 1]) 574 t2 = torch.ones([1, 1]) 575 x = [t1, t2] 576 y = [t2, t1] 577 return x == y 578 579 self.checkScript(test_tensor_list_equality, (), optimize=True) 580 581 def test_invalid_list_equality(): 582 t1 = torch.ones([2, 2]) 583 t2 = torch.ones([2, 2]) 584 x = [t1, t2] 585 y = [t2, t1] 586 # will throw since the tensors have more than one element 587 return x == y 588 589 self.checkScriptRaisesRegex( 590 test_invalid_list_equality, (), RuntimeError, "Boolean value of Tensor" 591 ) 592 593 def test_list_sort(self): 594 template = dedent( 595 """ 596 def func(): 597 li_1 = {list_create} 598 li_2 = {list_create} 599 li_3 = {list_create} 600 li_1.sort() 601 li_2.sort(reverse=True) 602 li_4 = sorted(li_3) 603 return li_1, li_2, li_3, li_4 604 """ 605 ) 606 607 lists = [ 608 "[]", 609 "[1, 3, 2]", 610 "[True, False, True]", 611 "[1.2, .2, 3.2]", 612 "[torch.tensor(1.0), torch.tensor(0.2), torch.tensor(0.5)]", 613 "[torch.tensor(5), torch.tensor(-2), torch.tensor(4)]", 614 ] 615 for li in lists: 616 code = template.format(list_create=li) 617 scope = {} 618 exec(code, globals(), scope) 619 cu = torch.jit.CompilationUnit(code) 620 t1 = cu.func() 621 t2 = scope["func"]() 622 self.assertEqual(t1, t2) 623 624 def test_fail(x: List[Tensor]) -> List[Tensor]: 625 x.sort() 626 return x 627 628 
self.checkScriptRaisesRegex( 629 test_fail, 630 (([torch.zeros([2]), torch.zeros([2])],)), 631 Exception, 632 "Boolean value of Tensor with more than one value", 633 ) 634 635 @torch.jit.script 636 def test_mutation(): 637 a = [1, 2, 3] 638 a.sort() 639 return a 640 641 test_mutation() 642 FileCheck().check("aten::sort").run(test_mutation.graph_for()) 643 644 def test_sorted_copy(): 645 a = [torch.tensor(2), torch.tensor(0), torch.tensor(1)] 646 b = sorted(a) 647 a[0] = torch.tensor(10) 648 return a, b 649 650 self.checkScript(test_sorted_copy, ()) 651 652 def test_list_slice(self): 653 def test_regular_slice(): 654 a = [0, 1, 2, 3, 4] 655 return a[2:3] == [2] 656 657 self.checkScript(test_regular_slice, ()) 658 659 def test_open_ended_slice(): 660 a = [0, 1, 2, 3, 4] 661 return a[2:] == [2, 3, 4] 662 663 self.checkScript(test_open_ended_slice, ()) 664 665 def test_open_ended_slice2(): 666 a = [0, 1, 2, 3, 4] 667 return a[:2] == [0, 1] 668 669 self.checkScript(test_open_ended_slice2, ()) 670 671 def test_negative_slice(): 672 a = [0, 1, 2, 3, 4] 673 return a[:-1] == [0, 1, 2, 3] 674 675 self.checkScript(test_negative_slice, ()) 676 677 def test_negative_slice2(): 678 a = [0, 1, 2, 3, 4] 679 return a[-3:-1] == [2, 3] 680 681 self.checkScript(test_negative_slice2, ()) 682 683 def test_backward_slice(): 684 a = [0, 1, 2, 3, 4] 685 return a[3:2] == torch.jit.annotate(List[int], []) 686 687 self.checkScript(test_backward_slice, ()) 688 689 def test_over_slice(): 690 a = [0, 1, 2, 3, 4] 691 return a[3:10] == [3, 4] 692 693 self.checkScript(test_backward_slice, ()) 694 695 def test_slice_index(self): 696 a = torch.tensor( 697 [ 698 [[1, 11], [2, 22]], 699 [[3, 33], [4, 44]], 700 [[5, 55], [6, 66]], 701 ] 702 ) 703 704 def test_index_slice1(x): 705 x = x[:, :, [0, 1]] 706 return x 707 708 self.checkScript(test_index_slice1, (a,)) 709 710 def test_index_slice2(x): 711 x = x[[2, 1, 0], :, :] 712 return x 713 714 self.checkScript(test_index_slice2, (a,)) 715 716 def 
test_index_slice3(x): 717 x = x[[0, 1], :, [1]] 718 return x 719 720 self.checkScript(test_index_slice3, (a,)) 721 722 def test_index_slice_empty_list(x): 723 empty_list: List[int] = [] 724 x = x[empty_list, :, :] 725 return x 726 727 self.checkScript(test_index_slice_empty_list, (a,)) 728 729 def test_index_slice_out_of_bounds_index(x): 730 x = x[[4], :, :] 731 return x 732 733 with self.assertRaisesRegexWithHighlight( 734 RuntimeError, 735 "index 4 is out of bounds for dimension 0 with size 3", 736 "x[[4], :, :]", 737 ): 738 self.checkScript(test_index_slice_out_of_bounds_index, (a,)) 739 740 def test_mutable_list_append(self): 741 def test_append(): 742 a = [0, 1] 743 a.append(2) 744 a.append(3) 745 return a == [0, 1, 2, 3] 746 747 self.checkScript(test_append, ()) 748 749 def test_comprehensions_basic(self): 750 def comp(l: List[int]) -> List[int]: 751 n = [x * 3 for x in l] 752 return n 753 754 comp([1, 2, 3]) 755 self.checkScript(comp, ([1, 2, 3],)) 756 757 def test_comprehensions_basic_float(self): 758 def comp(l: List[float]) -> List[float]: 759 n = [x * 3 for x in l] 760 return n 761 762 self.checkScript(comp, ([1.0, 2.0, 3.0],)) 763 764 def test_comprehensions_two_comps(self): 765 @torch.jit.script 766 def comp(l1: List[int], l2: List[int]) -> List[int]: 767 n = [x * 3 for x in l1] 768 n2 = [x + 2 for x in l2] 769 return n + n2 770 771 self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7]) 772 773 def test_comprehension_out_type_not_in_type(self): 774 def list_cast() -> int: 775 li = [int(i) for i in [torch.tensor(0), torch.tensor(1), torch.tensor(2)]] 776 return li[0] + li[1] + li[2] 777 778 self.checkScript(list_cast, ()) 779 780 def test_comprehension_iterable(self): 781 def test_func(fn, inputs): 782 self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs)) 783 784 def foo(names: List[int], results: List[int]) -> List[Tuple[int, int]]: 785 return [(k + 5, v - 2) for k, v in zip(names, results)] 786 787 test_func(foo, ([1, 2, 4], [4, 7, 9])) 788 
test_func(foo, ([5], [4, 7, 9])) 789 790 def fn(x: int) -> List[int]: 791 return [i for i in range(x)] # noqa: C416 792 793 test_func(fn, (9,)) 794 test_func(fn, (0,)) 795 test_func(fn, (-1,)) 796 797 def changes_type(): 798 a = [float(i) for i in range(5)] 799 b = [float(i) for i in [1, 2, 3, 4]] 800 c = [(float(i), j) for i, j in enumerate([1, 2, 3, 8])] 801 return a, b, c 802 803 test_func(changes_type, ()) 804 805 def test_zero_iter(): 806 return [str(i) for i, j in zip("", "")] 807 808 test_func(test_zero_iter, ()) 809 810 def test_mutable_list_append_2(self): 811 def test_append_2(): 812 a = [0, 1] 813 a.append(2) 814 a = [1] 815 a.append(4) 816 return a == [1, 4] 817 818 self.checkScript(test_append_2, ()) 819 820 def test_mutable_list_append_if(self): 821 def test_append_if(): 822 a = [1] 823 if 1 == 1: 824 a.append(4) 825 return a == [1, 4] 826 827 self.checkScript(test_append_if, ()) 828 829 def test_mutable_list_append_if_else(self): 830 def test_append_if_else(): 831 a = [1] 832 if 1 == 2: 833 a.append(4) 834 else: 835 a.append(10) 836 return a == [1, 10] 837 838 self.checkScript(test_append_if_else, ()) 839 840 def test_mutable_list_append_loop(self): 841 def test_append_loop(): 842 a = torch.jit.annotate(List[int], []) 843 for i in range(5): 844 a.append(i) 845 846 return a == [0, 1, 2, 3, 4] 847 848 self.checkScript(test_append_loop, ()) 849 850 def test_mutable_list_append_loop_if(self): 851 def test_append_loop_if(): 852 a = torch.jit.annotate(List[int], []) 853 for i in range(5): 854 if i > 3: 855 a.append(i) 856 else: 857 a.append(0) 858 859 return a == [0, 0, 0, 0, 4] 860 861 self.checkScript(test_append_loop_if, ()) 862 863 def test_mutable_list_nested_loop(self): 864 def test_nested_loop(): 865 a = torch.jit.annotate(List[int], []) 866 for i in range(2): 867 for j in range(2): 868 a.append(i + j) 869 870 return a == [0, 1, 1, 2] 871 872 self.checkScript(test_nested_loop, ()) 873 874 def test_mutable_list_function_inline(self): 875 
@torch.jit.script 876 def bar(y: List[int]) -> None: 877 y.append(4) 878 879 @torch.jit.script 880 def foo(): 881 x = [1, 2, 3] 882 bar(x) 883 return x 884 885 self.assertEqual(foo(), [1, 2, 3, 4]) 886 887 def test_mutable_list_reverse_empty(self): 888 def test_reverse_empty(): 889 a = [] 890 a.reverse() 891 892 return a == [] 893 894 self.checkScript(test_reverse_empty, ()) 895 896 def test_mutable_list_reverse(self): 897 def test_reverse(): 898 a = [1, 2, 3, 4] 899 a.reverse() 900 901 return a == [4, 3, 2, 1] 902 903 self.checkScript(test_reverse, ()) 904 905 def test_mutable_tensor_list_reverse(self): 906 def test_tensor_reverse(): 907 a = [torch.tensor(1), torch.tensor(2)] 908 a.reverse() 909 910 return a == [torch.tensor(2), torch.tensor(1)] 911 912 self.checkScript(test_tensor_reverse, ()) 913 914 def test_mutable_list_pop_empty(self): 915 @torch.jit.script 916 def test_pop_empty(): 917 a = torch.jit.annotate(List[int], []) 918 return a.pop() 919 920 with self.assertRaisesRegexWithHighlight( 921 RuntimeError, "pop from empty list", "a.pop" 922 ): 923 test_pop_empty() 924 925 def test_mutable_list_pop(self): 926 def test_pop(): 927 a = [1, 2, 3, 4] 928 b = a.pop() 929 930 return b == 4 931 932 self.checkScript(test_pop, ()) 933 934 def test_mutable_list_pop2(self): 935 def test_pop2(): 936 a = [1, 2, 3, 4] 937 b = a.pop() 938 939 return len(a) == 3 940 941 self.checkScript(test_pop2, ()) 942 943 def test_mutable_list_pop_at(self): 944 def test_pop_at(): 945 a = [1, 2, 3, 4] 946 b = a.pop(1) 947 948 return b == 2 949 950 self.checkScript(test_pop_at, ()) 951 952 def test_mutable_list_pop_at2(self): 953 def test_pop_at2(): 954 a = [1, 2, 3, 4] 955 b = a.pop(1) 956 957 return len(a) == 3 958 959 self.checkScript(test_pop_at2, ()) 960 961 def test_mutable_list_pop_at_negative(self): 962 def test_pop_at_negative(): 963 a = [1, 2, 3, 4] 964 b = a.pop(-2) 965 966 return b == 3 967 968 self.checkScript(test_pop_at_negative, ()) 969 970 def 
test_mutable_list_pop_at_negative2(self): 971 def test_pop_at_negative2(): 972 a = [1, 2, 3, 4] 973 b = a.pop(-2) 974 975 return len(a) == 3 976 977 self.checkScript(test_pop_at_negative2, ()) 978 979 def test_mutable_list_pop_slice(self): 980 def test_pop_slice(): 981 a = [1, 2, 3, 4] 982 b = [1, 2, 3, 4] 983 984 a.pop() 985 b = b[:-1] 986 987 return a == b 988 989 self.checkScript(test_pop_slice, ()) 990 991 def test_mutable_list_clear_empty(self): 992 def test_clear_empty(): 993 a = torch.jit.annotate(List[int], []) 994 a.clear() 995 996 return len(a) == 0 997 998 self.checkScript(test_clear_empty, ()) 999 1000 def test_mutable_list_clear(self): 1001 def test_clear(): 1002 a = [1, 2, 3, 4] 1003 a.clear() 1004 1005 return len(a) == 0 1006 1007 self.checkScript(test_clear, ()) 1008 1009 def test_mutable_list_insert(self): 1010 def test_list_insert(): 1011 a = [1, 2, 3, 4] 1012 a.insert(2, 5) 1013 1014 return a == [1, 2, 5, 3, 4] 1015 1016 self.checkScript(test_list_insert, ()) 1017 1018 def test_mutable_list_insert_negative(self): 1019 def test_list_insert_negative(): 1020 a = [1, 2, 3, 4] 1021 a.insert(-1, 5) 1022 1023 return a == [1, 2, 3, 5, 4] 1024 1025 self.checkScript(test_list_insert_negative, ()) 1026 1027 def test_mutable_list_insert_neg_out_of_bounds(self): 1028 def test_list_insert_neg_out_of_bounds(): 1029 a = [1, 2, 3, 4] 1030 a.insert(-10, 5) 1031 1032 return a == [5, 1, 2, 3, 4] 1033 1034 self.checkScript(test_list_insert_neg_out_of_bounds, ()) 1035 1036 def test_mutable_list_insert_out_of_bounds(self): 1037 def test_list_insert_out_of_bounds(): 1038 a = [1, 2, 3, 4] 1039 a.insert(10, 5) 1040 1041 return a == [1, 2, 3, 4, 5] 1042 1043 self.checkScript(test_list_insert_out_of_bounds, ()) 1044 1045 def test_mutable_list_remove_not_existing(self): 1046 @torch.jit.script 1047 def test_list_remove_not_existing(): 1048 a = [1, 2, 3, 4] 1049 a.remove(5) 1050 1051 return a 1052 1053 with self.assertRaisesRegexWithHighlight( 1054 RuntimeError, "x not in 
list", "a.remove" 1055 ): 1056 test_list_remove_not_existing() 1057 1058 def test_mutable_list_remove(self): 1059 def test_list_remove(): 1060 a = [1, 2, 3, 4] 1061 a.remove(3) 1062 1063 return a == [1, 2, 4] 1064 1065 self.checkScript(test_list_remove, ()) 1066 1067 def test_str_list_remove(): 1068 a = ["foo", "bar"] 1069 a.remove("foo") 1070 1071 return a == ["bar"] 1072 1073 self.checkScript(test_str_list_remove, ()) 1074 1075 def test_list_index_not_existing(self): 1076 @torch.jit.script 1077 def list_index_not_existing(): 1078 a = [4, 1, 3, 2] 1079 i = a.index(5) 1080 1081 return i 1082 1083 with self.assertRaisesRegexWithHighlight( 1084 RuntimeError, "'5' is not in list", "a.index" 1085 ): 1086 list_index_not_existing() 1087 1088 def test_list_index(self): 1089 def list_index(): 1090 a = [4, 1, 3, 2] 1091 i = a.index(3) 1092 1093 return i == 2 1094 1095 self.checkScript(list_index, ()) 1096 1097 def list_str_index(): 1098 a = ["foo", "bar"] 1099 i = a.index("bar") 1100 1101 return i == 1 1102 1103 self.checkScript(list_str_index, ()) 1104 1105 def test_tensor_list_index(self): 1106 def tensor_list_index(): 1107 a = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)] 1108 i = a.index(torch.tensor(3)) 1109 1110 return i == 2 1111 1112 self.checkScript(tensor_list_index, ()) 1113 1114 def test_tensor_list_index_not_existing(self): 1115 @torch.jit.script 1116 def tensor_list_index_not_existing(): 1117 a = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)] 1118 i = a.index(torch.tensor(5)) 1119 1120 return i 1121 1122 with self.assertRaisesRegexWithHighlight( 1123 RuntimeError, "is not in list", "a.index" 1124 ): 1125 tensor_list_index_not_existing() 1126 1127 def test_list_count(self): 1128 def list_count(): 1129 a = [4, 1, 4, 2, 4] 1130 i = a.count(4) 1131 1132 return i == 3 1133 1134 self.checkScript(list_count, ()) 1135 1136 def list_str_count(): 1137 a = ["foo", "bar", "foo"] 1138 i = a.count("foo") 1139 1140 return i == 2 
1141 1142 self.checkScript(list_str_count, ()) 1143 1144 def test_list_count_not_existing(self): 1145 def list_count_not_existing(): 1146 a = [4, 1, 4, 2, 4] 1147 i = a.count(5) 1148 1149 return i == 0 1150 1151 self.checkScript(list_count_not_existing, ()) 1152 1153 def test_tensor_list_count(self): 1154 def tensor_list_count(): 1155 a = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)] 1156 i = a.count(torch.tensor(4)) 1157 1158 return i == 3 1159 1160 self.checkScript(tensor_list_count, ()) 1161 1162 def test_tensor_list_count_not_existing(self): 1163 def tensor_list_count_not_existing(): 1164 a = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)] 1165 i = a.count(torch.tensor(5)) 1166 1167 return i == 0 1168 1169 self.checkScript(tensor_list_count_not_existing, ()) 1170 1171 def test_mutable_list_remove_tensor(self): 1172 def test_list_remove_tensor(): 1173 a = [torch.ones(1), torch.zeros(1), torch.ones(2)] 1174 a.remove(torch.zeros(1)) 1175 1176 return len(a) == 2 1177 1178 self.checkScript(test_list_remove_tensor, ()) 1179 1180 def test_mutable_list_remove2(self): 1181 def test_list_remove2(): 1182 a = [1] 1183 a.remove(1) 1184 1185 return len(a) == 0 1186 1187 self.checkScript(test_list_remove2, ()) 1188 1189 def test_extend_list_mutable(self): 1190 @torch.jit.script 1191 def extend_list(a: List[Tensor], b: List[Tensor]) -> List[Tensor]: 1192 a.extend(b) 1193 return a 1194 1195 for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]: 1196 for r in [ 1197 [], 1198 [torch.rand(2)], 1199 [torch.rand(2), torch.rand(2), torch.rand(2)], 1200 ]: 1201 self.assertEqual(extend_list(l, r), l + r) 1202 1203 def test_extend_list_immutable(self): 1204 @torch.jit.script 1205 def extend_list(a: List[int], b: List[int]) -> List[int]: 1206 a.extend(b) 1207 return a 1208 1209 for l in [[], [1], [1, 2, 3]]: 1210 for r in [[], [1], [1, 2, 3]]: 1211 self.assertEqual(extend_list(l, r), l + r) 1212 1213 def 
test_copy_list_mutable(self): 1214 @torch.jit.script 1215 def copy_list(a: List[Tensor]) -> List[Tensor]: 1216 return a.copy() 1217 1218 for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]: 1219 self.assertEqual(copy_list(l), l) 1220 1221 def test_copy_list_immutable(self): 1222 @torch.jit.script 1223 def copy_list(a: List[int]) -> List[int]: 1224 return a.copy() 1225 1226 for l in [[], [1], [1, 2, 3]]: 1227 self.assertEqual(copy_list(l), l) 1228 1229 def test_min_max_single_list(self): 1230 def min_intlist(li: List[int]) -> int: 1231 return min(li) 1232 1233 def max_intlist(li: List[int]) -> int: 1234 return max(li) 1235 1236 def min_boollist(li: List[bool]) -> bool: 1237 return min(li) 1238 1239 def max_boollist(li: List[bool]) -> bool: 1240 return max(li) 1241 1242 def min_floatlist(li: List[float]) -> float: 1243 return min(li) 1244 1245 def max_floatlist(li: List[float]) -> float: 1246 return max(li) 1247 1248 int_lists = [1], [2, 1, 2], [-3, 4, 2], [-2, -7, 1, 4], [2, 1, 0, 4], [] 1249 1250 def check_list(fn, li): 1251 if len(li) == 0: 1252 self.checkScriptRaisesRegex(fn, (li,), Exception, "empty") 1253 else: 1254 self.checkScript(fn, (li,)) 1255 1256 for int_list in int_lists: 1257 check_list(min_intlist, int_list) 1258 check_list(max_intlist, int_list) 1259 1260 bool_li = [bool(x) for x in int_list] 1261 check_list(min_boollist, bool_li) 1262 check_list(max_boollist, bool_li) 1263 1264 float_li = [float(x) for x in int_list] 1265 check_list(min_floatlist, float_li) 1266 check_list(max_floatlist, float_li) 1267 1268 def test_to_list(self): 1269 """Unit tests for Tensor.tolist() function.""" 1270 1271 """ 1272 Boolean dtype unit tests. 
1273 """ 1274 1275 def to_list_bool_0D(x: torch.Tensor) -> bool: 1276 li = torch.jit.annotate(bool, x.tolist()) 1277 return li 1278 1279 def to_list_bool_1D(x: torch.Tensor) -> List[bool]: 1280 li = torch.jit.annotate(List[bool], x.tolist()) 1281 return li 1282 1283 def to_list_bool_2D(x: torch.Tensor) -> List[List[bool]]: 1284 li = torch.jit.annotate(List[List[bool]], x.tolist()) 1285 return li 1286 1287 def to_list_bool_3D(x: torch.Tensor) -> List[List[List[bool]]]: 1288 li = torch.jit.annotate(List[List[List[bool]]], x.tolist()) 1289 return li 1290 1291 self.checkScript(to_list_bool_0D, (torch.tensor(False, dtype=torch.bool),)) 1292 bool_input_1D = torch.tensor([True, False, True, False], dtype=torch.bool) 1293 self.checkScript(to_list_bool_1D, (bool_input_1D,)) 1294 bool_input_2D = torch.tensor( 1295 [[True, True, False], [False, True, False]], dtype=torch.bool 1296 ) 1297 self.checkScript(to_list_bool_2D, (bool_input_2D,)) 1298 bool_input_3D = torch.tensor( 1299 [[[True, False], [False, True]], [[True, False], [False, False]]], 1300 dtype=torch.bool, 1301 ) 1302 self.checkScript(to_list_bool_3D, (bool_input_3D,)) 1303 bool_input_noncontiguous = torch.tensor( 1304 [[[True, False], [False, True]], [[True, False], [False, False]]], 1305 dtype=torch.bool, 1306 ).transpose(0, 1) 1307 self.checkScript(to_list_bool_3D, (bool_input_noncontiguous,)) 1308 1309 """ 1310 Int dtype unit tests. 
1311 """ 1312 1313 def to_list_int_0D(x: torch.Tensor) -> int: 1314 li = torch.jit.annotate(int, x.tolist()) 1315 return li 1316 1317 def to_list_int_1D(x: torch.Tensor) -> List[int]: 1318 li = torch.jit.annotate(List[int], x.tolist()) 1319 return li 1320 1321 def to_list_int_2D(x: torch.Tensor) -> List[List[int]]: 1322 li = torch.jit.annotate(List[List[int]], x.tolist()) 1323 return li 1324 1325 def to_list_int_3D(x: torch.Tensor) -> List[List[List[int]]]: 1326 li = torch.jit.annotate(List[List[List[int]]], x.tolist()) 1327 return li 1328 1329 self.checkScript(to_list_int_0D, (torch.tensor(1, dtype=torch.long),)) 1330 int_input_1D = torch.tensor([1, 2, 3, 4], dtype=torch.long) 1331 self.checkScript(to_list_int_1D, (int_input_1D,)) 1332 int_input_2D = torch.tensor([[1, 2, 3], [3, 4, 5]], dtype=torch.long) 1333 self.checkScript(to_list_int_2D, (int_input_2D,)) 1334 int_input_3D = torch.tensor( 1335 [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long 1336 ) 1337 self.checkScript(to_list_int_3D, (int_input_3D,)) 1338 int_input_noncontiguous = torch.tensor( 1339 [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long 1340 ).transpose(0, 1) 1341 self.checkScript(to_list_int_3D, (int_input_noncontiguous,)) 1342 1343 """ 1344 Float dtype unit tests. 1345 """ 1346 1347 def to_list_float_0D(x: torch.Tensor) -> float: 1348 li = torch.jit.annotate(float, x.tolist()) 1349 return li 1350 1351 def to_list_float_1D(x: torch.Tensor) -> List[float]: 1352 li = torch.jit.annotate(List[float], x.tolist()) 1353 return li 1354 1355 def to_list_float_2D(x: torch.Tensor) -> List[List[float]]: 1356 li = torch.jit.annotate(List[List[float]], x.tolist()) 1357 return li 1358 1359 def to_list_float_3D(x: torch.Tensor) -> List[List[List[float]]]: 1360 li = torch.jit.annotate(List[List[List[float]]], x.tolist()) 1361 return li 1362 1363 # Test with torch.float dtype Tensors to check that they are converted to double automatically. 
1364 self.checkScript(to_list_float_0D, (torch.randn(5, dtype=torch.float)[0],)) 1365 self.checkScript(to_list_float_1D, (torch.randn(5, dtype=torch.float),)) 1366 self.checkScript(to_list_float_2D, (torch.randn(5, 6, dtype=torch.float),)) 1367 self.checkScript(to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.float),)) 1368 self.checkScript( 1369 to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.float).transpose(0, 1),) 1370 ) 1371 1372 self.checkScript(to_list_float_0D, (torch.randn(5, dtype=torch.double)[0],)) 1373 self.checkScript(to_list_float_1D, (torch.randn(5, dtype=torch.double),)) 1374 self.checkScript(to_list_float_2D, (torch.randn(5, 6, dtype=torch.double),)) 1375 self.checkScript(to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.double),)) 1376 self.checkScript( 1377 to_list_float_3D, 1378 (torch.randn(5, 6, 7, dtype=torch.double).transpose(0, 1),), 1379 ) 1380 1381 """ 1382 Complex dtype unit tests. 1383 """ 1384 1385 def to_list_complex_0D(x: torch.Tensor) -> complex: 1386 li = torch.jit.annotate(complex, x.tolist()) 1387 return li 1388 1389 def to_list_complex_1D(x: torch.Tensor) -> List[complex]: 1390 li = torch.jit.annotate(List[complex], x.tolist()) 1391 return li 1392 1393 def to_list_complex_2D(x: torch.Tensor) -> List[List[complex]]: 1394 li = torch.jit.annotate(List[List[complex]], x.tolist()) 1395 return li 1396 1397 def to_list_complex_3D(x: torch.Tensor) -> List[List[List[complex]]]: 1398 li = torch.jit.annotate(List[List[List[complex]]], x.tolist()) 1399 return li 1400 1401 # Test with torch.complex dtype Tensors to check that they are converted to double automatically. 
1402 self.checkScript(to_list_complex_0D, (torch.randn(5, dtype=torch.cfloat)[0],)) 1403 self.checkScript(to_list_complex_1D, (torch.randn(5, dtype=torch.cfloat),)) 1404 self.checkScript(to_list_complex_2D, (torch.randn(5, 6, dtype=torch.cfloat),)) 1405 self.checkScript( 1406 to_list_complex_3D, (torch.randn(5, 6, 7, dtype=torch.cfloat),) 1407 ) 1408 self.checkScript( 1409 to_list_complex_3D, 1410 (torch.randn(5, 6, 7, dtype=torch.cfloat).transpose(0, 1),), 1411 ) 1412 1413 self.checkScript(to_list_complex_0D, (torch.randn(5, dtype=torch.cdouble)[0],)) 1414 self.checkScript(to_list_complex_1D, (torch.randn(5, dtype=torch.cdouble),)) 1415 self.checkScript(to_list_complex_2D, (torch.randn(5, 6, dtype=torch.cdouble),)) 1416 self.checkScript( 1417 to_list_complex_3D, (torch.randn(5, 6, 7, dtype=torch.cdouble),) 1418 ) 1419 self.checkScript( 1420 to_list_complex_3D, 1421 (torch.randn(5, 6, 7, dtype=torch.cdouble).transpose(0, 1),), 1422 ) 1423 1424 """ 1425 Non-happy path tests: 1426 - missing type annotation 1427 - mismatch between type annotation and input 1428 - type annotation with unsupported type 1429 - type annotation with the wrong dimension 1430 - type annotation with scalar type that doesn't match the input scalar type 1431 """ 1432 1433 def to_list_missing_type_annotation(x: torch.Tensor) -> List[float]: 1434 li = x.tolist() 1435 return li 1436 1437 def to_list_incorrect_type_annotation(x: torch.Tensor) -> List[float]: 1438 li = torch.jit.annotate(float, x.tolist()) 1439 return li 1440 1441 def to_list_unsupported_type_annotation(x: torch.Tensor) -> List[float]: 1442 li = torch.jit.annotate(List[str], x.tolist()) 1443 return li 1444 1445 def to_list_type_annotation_wrong_dim(x: torch.Tensor) -> List[List[float]]: 1446 li = torch.jit.annotate(List[List[float]], x.tolist()) 1447 return li 1448 1449 def to_list_type_annotation_incorrect_scalar_type( 1450 x: torch.Tensor, 1451 ) -> List[float]: 1452 li = torch.jit.annotate(List[float], x.tolist()) 1453 return li 
1454 1455 with self.assertRaisesRegexWithHighlight( 1456 RuntimeError, r"Expected type hint for result of tolist()", "x.tolist(" 1457 ): 1458 self.checkScript(to_list_missing_type_annotation, (torch.randn(5),)) 1459 1460 with self.assertRaisesRegexWithHighlight( 1461 RuntimeError, 1462 r"Return value was annotated as having type List\[float\] but is actually of type float", 1463 "return li", 1464 ): 1465 self.checkScript(to_list_incorrect_type_annotation, (torch.randn(5),)) 1466 1467 with self.assertRaisesRegex( 1468 RuntimeError, r"str is not one of the supported element types for tolist" 1469 ): 1470 self.checkScript(to_list_unsupported_type_annotation, (torch.randn(5),)) 1471 1472 with self.assertRaisesRegex( 1473 RuntimeError, 1474 r"Output annotation list dimension and runtime tensor dimension must match", 1475 ): 1476 self.checkScript( 1477 to_list_type_annotation_wrong_dim, (torch.randn(5, dtype=torch.double),) 1478 ) 1479 1480 with self.assertRaisesRegex( 1481 RuntimeError, 1482 r"Output annotation element type and runtime tensor element type must match", 1483 ): 1484 self.checkScript( 1485 to_list_type_annotation_incorrect_scalar_type, 1486 (torch.ones(5, dtype=torch.long),), 1487 ) 1488 1489 @unittest.skipIf(not TEST_CUDA, "CUDA not available") 1490 def test_to_list_gpu(self): 1491 """GPU tests for Tensor.tolist() function.""" 1492 1493 def to_list_bool_1D(x: torch.Tensor) -> List[bool]: 1494 li = torch.jit.annotate(List[bool], x.tolist()) 1495 return li 1496 1497 def to_list_int_1D(x: torch.Tensor) -> List[int]: 1498 li = torch.jit.annotate(List[int], x.tolist()) 1499 return li 1500 1501 def to_list_float_1D(x: torch.Tensor) -> List[float]: 1502 li = torch.jit.annotate(List[float], x.tolist()) 1503 return li 1504 1505 self.checkScript( 1506 to_list_bool_1D, 1507 (torch.tensor([True, False, True, False], dtype=torch.bool).cuda(),), 1508 ) 1509 self.checkScript( 1510 to_list_int_1D, (torch.tensor([1, 2, 3, 4], dtype=torch.long).cuda(),) 1511 ) 1512 
self.checkScript(to_list_float_1D, (torch.randn(5, dtype=torch.double).cuda(),)) 1513 1514 def test_no_element_type_annotation(self): 1515 def fn_with_comment(x: torch.Tensor) -> List: 1516 a: List = x.tolist() 1517 return a 1518 1519 def annotated_fn(x: torch.Tensor) -> List: 1520 a: List = x.tolist() 1521 return a 1522 1523 with self.assertRaisesRegex( 1524 RuntimeError, r"Attempted to use List without a contained type" 1525 ): 1526 cu = torch.jit.CompilationUnit() 1527 cu.define(dedent(inspect.getsource(fn_with_comment))) 1528 1529 with self.assertRaisesRegex( 1530 RuntimeError, r"Attempted to use List without a contained type" 1531 ): 1532 cu = torch.jit.CompilationUnit() 1533 cu.define(dedent(inspect.getsource(annotated_fn))) 1534 1535 with self.assertRaisesRegex( 1536 RuntimeError, r"Attempted to use List without a contained type" 1537 ): 1538 torch.jit.script(fn_with_comment) 1539 1540 with self.assertRaisesRegex( 1541 RuntimeError, r"Attempted to use List without a contained type" 1542 ): 1543 torch.jit.script(annotated_fn) 1544 1545 def test_list_none(self): 1546 with self.assertRaisesRegex( 1547 RuntimeError, "Can not create ListType with None type" 1548 ): 1549 x = torch._C.ListType(None) 1550 1551 def test_list_unification_hint(self): 1552 with self.assertRaisesRegex( 1553 RuntimeError, "Expected an annotation of type List" 1554 ): 1555 1556 @torch.jit.script 1557 def x(): 1558 b: int = [2, 3] 1559 return b 1560 1561 1562class TestDict(JitTestCase): 1563 def dict(self): 1564 return {"a": torch.ones(1), "b": torch.ones(1) + 1, "c": torch.ones(1) + 2} 1565 1566 def dict2(self): 1567 return { 1568 "x": torch.ones(1) + 100, 1569 "y": torch.ones(1) + 101, 1570 "z": torch.ones(1) + 102, 1571 } 1572 1573 def dict_bool(self): 1574 return {True: 1} 1575 1576 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1577 def test_dict_bool_conversion(self): 1578 def if_predicate(d: Dict[int, int]): 1579 if d: 1580 s, t = 0, 0 1581 for k, v in 
d.items(): 1582 s += k 1583 t += v 1584 1585 return s, t 1586 else: 1587 return -1, -1 1588 1589 self.checkScript(if_predicate, ({1: 2, 3: 5},)) 1590 self.checkScript(if_predicate, ({},)) 1591 1592 def while_predicate(d: Dict[int, int]): 1593 while d: 1594 d.clear() 1595 1596 self.checkScript(while_predicate, ({1: 2, 3: 5},)) 1597 self.checkScript(while_predicate, ({},)) 1598 1599 def ternary_predicate(d: Dict[int, int]): 1600 return "non-empty" if d else "empty" 1601 1602 self.checkScript(ternary_predicate, ({1: 2, 3: 5},)) 1603 self.checkScript(ternary_predicate, ({},)) 1604 1605 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1606 def test_del(self): 1607 def inputs(): 1608 return {"hi": 2, "bye": 3} 1609 1610 def fn(x: Dict[str, int]) -> Dict[str, int]: 1611 del x["hi"] 1612 return x 1613 1614 python_out = fn(inputs()) 1615 # checkScript reuses the same object, but here it's being mutated so do 1616 # it manually 1617 cu = torch.jit.CompilationUnit() 1618 cu.define(dedent(inspect.getsource(fn))) 1619 self.assertEqual(cu.fn(inputs()), python_out) 1620 self.assertEqual(torch.jit.script(fn)(inputs()), python_out) 1621 with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", 'x["hi"]'): 1622 self.checkScript(fn, [{}]) 1623 1624 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1625 def test_dict_variance(self): 1626 """ 1627 `Dict[T1, _]` is not a subtype of `Dict[T2, _]`, even if `T1` is 1628 a subtype of `T2`; similarly `Dict[_, T1]` would not be a 1629 subtype of `Dict[_, T2]`. 
1630 1631 However, if we have a temporary dict object (that is, a dict 1632 comprehension or a dict literal) on the rhs of an assignment 1633 statement, we want to ignore the inferred type of the rhs if we 1634 can prove that: 1) both the lhs and the rhs are dicts with the 1635 same key types (TorchScript has a restricted set of allowed key 1636 types, so we don't need to worry about subtyping relationships 1637 here), and 2) the value type of the dict is a subtype of the 1638 value type of the rhs dict. 1639 """ 1640 1641 def test_dictliteral_is_typed_from_annotation(): 1642 x: Dict[str, Optional[int]] = {"foo": None, "bar": None, "baz": None} 1643 return x 1644 1645 self.checkScript(test_dictliteral_is_typed_from_annotation, ()) 1646 1647 def test_dictcomprehension_is_typed_from_annotation(): 1648 metasyntactics = ["foo", "bar", "baz"] 1649 x: Dict[str, Optional[int]] = { # noqa: C420, RUF025 1650 word: None for word in metasyntactics 1651 } 1652 return x 1653 1654 self.checkScript(test_dictcomprehension_is_typed_from_annotation, ()) 1655 1656 def test_dicts_with_different_value_types_are_invariant(self): 1657 x: Dict[str, int] = {"foo": 1, "bar": 2, "baz": 3} 1658 y: Dict[str, Optional[int]] = x 1659 return x 1660 1661 with self.assertRaisesRegex( 1662 RuntimeError, 1663 "Variable 'y' is " 1664 "annotated with type " 1665 r"Dict\[str, Optional\[int\]\] but " 1666 "is being assigned to a value of " 1667 r"type Dict\[str, int\]", 1668 ): 1669 torch.jit.script(test_dicts_with_different_value_types_are_invariant) 1670 1671 def test_dicts_with_different_value_types_are_invariant_recursive(self): 1672 x: Dict[str, int] = {"foo": 1, "bar": 2, "baz": 3} 1673 y: Dict[str, Dict[str, int]] = {"foo": x, "bar": x, "baz": x} 1674 z: Dict[str, Dict[str, Optional[int]]] = y 1675 return x 1676 1677 with self.assertRaisesRegex( 1678 RuntimeError, 1679 "Variable 'z' is " 1680 "annotated with type " 1681 r"Dict\[str, Dict\[str, Optional" 1682 r"\[int\]\]\] but is being assigned" 
1683 r" to a value of type Dict\[str, " 1684 r"Dict\[str, int\]\]", 1685 ): 1686 torch.jit.script( 1687 test_dicts_with_different_value_types_are_invariant_recursive 1688 ) 1689 1690 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1691 def test_keys(self): 1692 @torch.jit.script 1693 def keys(x: Dict[str, Tensor]) -> List[str]: 1694 return list(x.keys()) 1695 1696 self.assertEqual(set(keys(self.dict())), set(self.dict().keys())) 1697 1698 @torch.jit.script 1699 def specialized_list(): 1700 li = {1: 1, 2: 2}.keys() 1701 li.append(3) 1702 return li 1703 1704 self.assertTrue(set(specialized_list()) == {1, 2, 3}) 1705 1706 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1707 def test_values(self): 1708 @torch.jit.script 1709 def values(x: Dict[str, Tensor]) -> List[Tensor]: 1710 return list(x.values()) 1711 1712 the_dict = self.dict() 1713 self.assertEqual(set(values(the_dict)), set(the_dict.values())) 1714 1715 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1716 def test_len(self): 1717 def length(x: Dict[str, Tensor]) -> int: 1718 return len(x) 1719 1720 self.checkScript(length, (self.dict(),)) 1721 1722 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1723 def test_copy(self): 1724 def func(x: Dict[str, Tensor]) -> Dict[str, Tensor]: 1725 return x.copy() 1726 1727 self.checkScript(func, (self.dict(),)) 1728 1729 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1730 def test_items(self): 1731 def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]: 1732 return x.items() 1733 1734 # The value returned by Python is in arbitrary order, so we can't use 1735 # checkScript 1736 scripted_func = torch.jit.script(func) 1737 1738 eager_out = func(self.dict()) 1739 script_out = scripted_func(self.dict()) 1740 1741 self.assertEqual(len(eager_out), len(script_out)) 1742 for item in eager_out: 1743 self.assertTrue(item in script_out) 1744 1745 
@skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
def test_pop(self):
    """
    Check that ``Dict.pop`` behaves the same in TorchScript as in eager
    Python, both with and without a default value. Also check that popping
    a missing key without a default raises a KeyError highlighted at the
    ``x.pop`` call.
    """

    def pop(x: Dict[str, Tensor], key: str) -> Tuple[Tensor, Dict[str, Tensor]]:
        return x.pop(key), x

    def default_pop(
        x: Dict[str, Tensor], key: str, default: Tensor
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        return x.pop(key, default), x

    def run_both_and_compare(fn, *rest):
        # pop() mutates its dict argument, and checkScript does not copy its
        # inputs, so every invocation here gets a freshly built dict instead.
        eager_result = fn(self.dict(), *rest)
        scripted_fn = torch.jit.script(fn)
        self.assertEqual(eager_result, scripted_fn(self.dict(), *rest))

    run_both_and_compare(pop, "a")

    # Without a default, a missing key must surface as a KeyError.
    with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", "x.pop"):
        torch.jit.script(pop)(self.dict(), "x")

    # "a" exists in self.dict(); "x" does not, so the default is returned.
    for lookup_key in ("a", "x"):
        run_both_and_compare(default_pop, lookup_key, torch.randn(2, 2))


@skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
def test_setdefault(self):
    """
    Check ``Dict.setdefault`` both for a key that is already present and
    for one that is absent, comparing the returned dict against eager Python.
    """

    def setdefault(
        x: Dict[str, Tensor], key: str, default: Tensor
    ) -> Dict[str, Tensor]:
        x.setdefault(key, default)
        return x

    for key in ("a", "nonexistant"):
        self.checkScript(setdefault, (self.dict(), key, torch.randn(2, 2)))


@skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
def test_update(self):
    """
    Check ``Dict.update`` when the second dict shares all keys with the
    first (self.dict) and when its keys are disjoint (self.dict2).
    """

    def update(
        a: Dict[str, Tensor], b: Dict[str, Tensor]
    ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
        a.update(b)
        return a, b

    for make_other in (self.dict, self.dict2):
        self.checkScript(update, (self.dict(), make_other()))


@skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
def test_update_existing_key(self):
    """
    Repeatedly ``update`` the same key inside a loop starting from an
    annotated empty dict, and compare the scripted result against eager.
    """

    def overwrite_same_key() -> Dict[str, int]:
        result: Dict[str, int] = {}
        for step in range(3):
            result.update({"a": step})
        return result

    self.checkScript(overwrite_same_key, ())
1801 1802 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1803 def test_aug_assign(self): 1804 def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]: 1805 a["a"] += 1 1806 a["b"] -= 12 1807 a["c"] *= 122 1808 a["c"] /= 2 1809 a["c"] %= 2 1810 return a 1811 1812 def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]: 1813 a["a"] += 3.4 1814 a["b"] -= 2.4 1815 a["c"] *= 3.0 1816 a["c"] /= 2.0 1817 a["c"] %= 2.0 1818 return a 1819 1820 self.checkScript(aug_assign_dict_tensor, (self.dict(),)) 1821 self.checkScript(aug_assign_dict_prim, ({"a": 3.0, "b": 2.0, "c": 4.0},)) 1822 1823 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1824 def test_popitem(self): 1825 @torch.jit.script 1826 def popitem( 1827 x: Dict[str, Tensor] 1828 ) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]: 1829 item = x.popitem() 1830 return item, x 1831 1832 # The value returned by Python is arbitrary, so we can't use checkScript 1833 eager_in = self.dict() 1834 eager_out = (eager_in.popitem(), eager_in) 1835 1836 script_out = popitem(self.dict()) 1837 1838 # Check that an item was removed 1839 self.assertEqual(len(eager_out[1]), len(script_out[1])) 1840 1841 # Check that the item is the correct types 1842 self.assertTrue(isinstance(script_out[0][0], str)) 1843 self.assertTrue(isinstance(script_out[0][1], torch.Tensor)) 1844 1845 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1846 def test_clear(self): 1847 def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]: 1848 x.clear() 1849 return x 1850 1851 self.checkScript(clear, (self.dict(),)) 1852 1853 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1854 def test_get(self): 1855 def get(x: Dict[str, Tensor], key: str) -> Optional[Tensor]: 1856 return x.get(key) 1857 1858 self.checkScript(get, (self.dict(), "a")) 1859 self.checkScript(get, (self.dict(), "doesn't exist")) 1860 1861 def get_default(x: Dict[str, Tensor], key: str) 
-> Optional[Tensor]: 1862 return x.get(key, torch.randn(2, 2)) 1863 1864 self.checkScript(get, (self.dict(), "a")) 1865 self.checkScript(get, (self.dict(), "doesn't exist")) 1866 1867 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1868 def test_get_boolkey(self): 1869 def get(x: Dict[bool, int], key: bool) -> Optional[int]: 1870 return x.get(key) 1871 1872 self.checkScript(get, (self.dict_bool(), True)) 1873 self.checkScript(get, (self.dict_bool(), False)) 1874 1875 def get_default(x: Dict[bool, int], key: bool) -> int: 1876 return x.get(key, 42) 1877 1878 self.checkScript(get_default, (self.dict_bool(), True)) 1879 self.checkScript(get_default, (self.dict_bool(), False)) 1880 1881 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1882 def test_basic(self): 1883 def simple(x: Dict[str, int]) -> Dict[str, int]: 1884 return x 1885 1886 self.checkScript(simple, ({"item": 20, "other_item": 120},)) 1887 1888 def index(x: Dict[str, int]) -> int: 1889 return x["item"] 1890 1891 self.checkScript(index, ({"item": 20, "other_item": 120},)) 1892 1893 def type_default() -> Dict[str, Tensor]: 1894 return {} 1895 1896 self.checkScript(type_default, ()) 1897 1898 @torch.jit.script 1899 def missing_index(x: Dict[str, int]) -> int: 1900 return x["dne"] 1901 1902 with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", 'x["dne"'): 1903 missing_index({"item": 20, "other_item": 120}) 1904 1905 code = dedent( 1906 """ 1907 def literal1(): 1908 return torch.jit.annotate(Dict[int, float], {}) 1909 def literal2(): 1910 return torch.jit.annotate(Dict[int, float], {10: 1.2}) 1911 """ 1912 ) 1913 cu = torch.jit.CompilationUnit(code) 1914 self.assertEqual({}, cu.literal1()) 1915 self.assertEqual({10: 1.2}, cu.literal2()) 1916 1917 cu = torch.jit.CompilationUnit( 1918 dedent( 1919 """ 1920 def literal3(): 1921 return torch.jit.annotate(Dict[int, float], {10: 1.2, 11: 1.3}) 1922 """ 1923 ) 1924 ) 1925 self.assertEqual({10: 1.2, 11: 
1.3}, cu.literal3()) 1926 1927 def list_of_dicts() -> List[Dict[str, Tensor]]: 1928 return [{"word": torch.ones(2) + 3}, {"other word": torch.ones(1) + 2}] 1929 1930 self.checkScript(list_of_dicts, ()) 1931 1932 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1933 def test_mutability(self): 1934 @torch.jit.script 1935 def fn() -> Dict[str, int]: 1936 a = torch.jit.annotate(Dict[str, int], {}) 1937 a["ok"] = 10 1938 return a 1939 1940 self.assertEqual(fn(), {"ok": 10}) 1941 1942 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1943 def test_key_type(self): 1944 with self.assertRaisesRegexWithHighlight( 1945 RuntimeError, "but instead found type", "a[None]" 1946 ): 1947 1948 @torch.jit.script 1949 def fn(a: Dict[str, int]) -> int: 1950 return a[None] 1951 1952 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1953 def test_loop(self): 1954 @torch.jit.script 1955 def fn(x: int) -> Dict[str, int]: 1956 a = torch.jit.annotate(Dict[str, int], {}) 1957 for i in range(x): 1958 a["ok"] = i 1959 return a 1960 1961 self.assertEqual(fn(10), {"ok": 9}) 1962 1963 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1964 def test_view(self): 1965 def fn(x, y): 1966 l = {"a": x} 1967 x_view = l["a"] 1968 a = x + x 1969 x_view.add_(y) 1970 b = x + x 1971 return a == b 1972 1973 self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3))) 1974 1975 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 1976 def test_membership(self): 1977 def fn(x: Dict[int, int], y: int) -> int: 1978 return x.get(y, 3) 1979 1980 d = {1: 2, 3: 4} 1981 self.checkScript(fn, (d, 3)) 1982 self.checkScript(fn, (d, 2)) 1983 1984 def optional(x: Dict[int, int], y: int) -> bool: 1985 res = x.get(y) 1986 return res is None 1987 1988 self.checkScript(fn, (d, 3)) 1989 self.checkScript(fn, (d, 2)) 1990 1991 with self.assertRaisesRegexWithHighlight( 1992 RuntimeError, "is actually of type Optional", 
"return x.get(y" 1993 ): 1994 1995 @torch.jit.script 1996 def bad_types(x: Dict[int, int], y: int) -> int: 1997 return x.get(y) # noqa: T484 1998 1999 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 2000 def test_dict_to_python(self): 2001 @torch.jit.ignore 2002 def python_lookup(my_dict: Dict[str, int], keys: List[str]) -> List[int]: 2003 return [my_dict[k] for k in keys] 2004 2005 def fn(my_dict: Dict[str, int], keys: List[str]) -> List[int]: 2006 return python_lookup(my_dict, keys) 2007 2008 a_dict = {"a": torch.ones(1), "b": torch.ones(1) + 1, "c": torch.ones(1) + 2} 2009 self.checkScript(fn, (a_dict, ("a", "c"))) 2010 2011 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 2012 def test_ordered_dict(self): 2013 def test_func(fn, inputs): 2014 self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs)) 2015 2016 def repeated_key(): 2017 return OrderedDict([(1, 2), (2, 3), (1, 4)]) 2018 2019 test_func(repeated_key, ()) 2020 2021 def no_args(): 2022 a = OrderedDict() 2023 a["one"] = torch.tensor(1) 2024 a["two"] = torch.tensor(2) 2025 2026 test_func(no_args, ()) 2027 2028 def test_dict_constructor(): 2029 a = dict() # noqa: C408 2030 a["one"] = torch.tensor(1) 2031 return a, dict([(1, 2), (2, 3), (1, 4)]) # noqa: C406 2032 2033 test_func(test_dict_constructor, ()) 2034 2035 def test_dict_initializer_list(): 2036 a = {"1": torch.tensor(1), "2": torch.tensor(2)} 2037 output_order = [] 2038 for key in a: 2039 output_order.append(a[key]) 2040 return output_order 2041 2042 test_func(test_dict_initializer_list, ()) 2043 2044 def test_dict_error(): 2045 a = dict() # noqa: C408 2046 a[1] = 2 2047 return a 2048 2049 with self.assertRaisesRegexWithHighlight( 2050 Exception, "Arguments for call are not", "a[1] = 2" 2051 ): 2052 torch.jit.script(test_dict_error) 2053 2054 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 2055 def test_type_annotation_missing_contained_type(self): 2056 """ 2057 Test that 
the use of a Dict type annotation without contained 2058 key and value types produces an error. 2059 """ 2060 2061 # This function uses a type comment. 2062 def fn_with_comment(input: Dict) -> Any: 2063 return input 2064 2065 # This function uses Python3 style type annotations. 2066 def annotated_fn(input: Dict) -> Any: 2067 return input 2068 2069 with self.assertRaisesRegex( 2070 RuntimeError, r"Attempted to use Dict without contained types" 2071 ): 2072 cu = torch.jit.CompilationUnit() 2073 cu.define(dedent(inspect.getsource(fn_with_comment))) 2074 2075 with self.assertRaisesRegex( 2076 RuntimeError, r"Attempted to use Dict without contained types" 2077 ): 2078 cu = torch.jit.CompilationUnit() 2079 cu.define(dedent(inspect.getsource(annotated_fn))) 2080 2081 with self.assertRaisesRegex( 2082 RuntimeError, r"Attempted to use Dict without contained types" 2083 ): 2084 m = torch.jit.script(fn_with_comment) 2085 2086 with self.assertRaisesRegex( 2087 RuntimeError, r"Attempted to use Dict without contained types" 2088 ): 2089 m = torch.jit.script(annotated_fn) 2090 2091 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 2092 def test_dict_preserves_order(self): 2093 def dict_ordering(): 2094 a: Dict[int, int] = {} 2095 for i in range(1000): 2096 a[i] = i + 1 2097 return a 2098 2099 self.checkScript(dict_ordering, ()) 2100 di = torch.jit.script(dict_ordering)() 2101 res = list(di.items()) 2102 for i in range(1000): 2103 key, value = res[i] 2104 self.assertTrue(key == i and value == i + 1) 2105 2106 @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason") 2107 def test_optional_dict_construct(self): 2108 class M(torch.nn.Module): 2109 def use(self, buffer: Dict[str, Optional[torch.Tensor]]): 2110 return buffer["prev_key"] 2111 2112 def forward(self, x): 2113 prev_key = torch.rand(2, 3) 2114 next_key = torch.rand(2, 3) 2115 saved_state: Dict[str, Optional[torch.Tensor]] = { 2116 "prev_key": prev_key, 2117 "next_key": next_key, 2118 } 
2119 2120 return self.use(saved_state) 2121 2122 self.checkModule(M(), (torch.rand(2, 2),)) 2123 2124 2125class TestNamedTuple(JitTestCase): 2126 def test_namedtuple(self): 2127 class FeatureVector(NamedTuple): 2128 float_features: float 2129 sequence_features: List[float] 2130 time_since_first: float 2131 2132 @torch.jit.script 2133 def foo(x) -> float: 2134 fv = FeatureVector(3.0, [3.0], 3.0) 2135 rv = fv.float_features 2136 for val in fv.sequence_features: 2137 rv += val 2138 rv *= fv.time_since_first 2139 return rv 2140 2141 self.assertEqual(foo(torch.rand(3, 4)), 18.0) 2142 2143 def test_namedtuple_constant(self): 2144 class Tup(NamedTuple): 2145 a: int 2146 b: int 2147 2148 @torch.jit.script 2149 def foo(): 2150 return Tup(1, 2) 2151 2152 self.assertEqual(foo(), Tup(1, 2)) 2153 2154 def test_return_named_tuple(self): 2155 class FeatureVector(NamedTuple): 2156 float_features: float 2157 sequence_features: List[float] 2158 time_since_first: float 2159 2160 @torch.jit.script 2161 def foo(x): 2162 fv = FeatureVector(3.0, [3.0], 3.0) 2163 return fv 2164 2165 out = foo(torch.rand(3, 4)) 2166 out = foo(torch.rand(3, 4)) 2167 self.assertEqual(out.float_features, 3.0) 2168 self.assertEqual(out.sequence_features, [3.0]) 2169 self.assertEqual(out.time_since_first, 3.0) 2170 2171 def test_namedtuple_as_attr(self): 2172 class Config(NamedTuple): 2173 size: int 2174 2175 class MyMod(nn.Module): 2176 configs: Dict[int, Config] 2177 2178 def __init__(self, configs): 2179 super().__init__() 2180 self.configs = configs 2181 2182 def forward(self, x): 2183 for config in self.configs.values(): 2184 x += config.size 2185 return x 2186 2187 s = torch.jit.script(MyMod({0: Config(size=16)})) 2188 2189 def test_namedtuple_resolution(self): 2190 class TheType(NamedTuple): 2191 t: int 2192 2193 class MyModule(types.ModuleType): 2194 def __init__(self) -> None: 2195 super().__init__("MyModule") 2196 2197 def __getattr__(self, attr): 2198 return TheType 2199 2200 some_module = MyModule() 
2201 2202 def fn() -> some_module.Type: 2203 return some_module.Type(1) 2204 2205 self.checkScript(fn, []) 2206 2207 def test_namedtuple_slice_unpack(self): 2208 class MyCoolNamedTuple(NamedTuple): 2209 a: int 2210 b: float 2211 c: List[int] 2212 2213 @torch.jit.script 2214 def foo(a: int, b: float, c: List[int]): 2215 tup = MyCoolNamedTuple(a, b, c) 2216 my_a, my_b, my_c = tup 2217 return tup[:1], my_a, my_c 2218 2219 self.assertEqual(foo(3, 3.5, [6]), ((3,), 3, [6])) 2220 2221 def test_namedtuple_lower(self): 2222 class MyCoolNamedTuple(NamedTuple): 2223 a: int 2224 b: float 2225 c: List[int] 2226 2227 @torch.jit.script 2228 def foo(a: int): 2229 tup = MyCoolNamedTuple(a, 3.14, [9]) 2230 return tup 2231 2232 FileCheck().check("TupleConstruct").run(foo.graph) 2233 torch._C._jit_pass_lower_all_tuples(foo.graph) 2234 FileCheck().check_not("TupleConstruct").run(foo.graph) 2235 2236 def test_namedtuple_type_annotation(self): 2237 global MyCoolNamedTuple # see [local resolution in python] 2238 2239 class MyCoolNamedTuple(NamedTuple): 2240 a: int 2241 b: float 2242 c: List[int] 2243 2244 @torch.jit.script 2245 def foo(x: MyCoolNamedTuple) -> MyCoolNamedTuple: 2246 return x 2247 2248 mnt = MyCoolNamedTuple(42, 420.0, [666]) 2249 self.assertEqual(foo(mnt), mnt) 2250 2251 def test_namedtuple_wrong_types(self): 2252 class MyCoolNamedTuple(NamedTuple): 2253 a: int 2254 b: float 2255 c: List[int] 2256 2257 with self.assertRaisesRegex( 2258 RuntimeError, 2259 "Expected a value of type 'int' for argument 'a'" 2260 " but instead found type 'str'", 2261 ): 2262 2263 @torch.jit.script 2264 def foo(): 2265 tup = MyCoolNamedTuple("foo", "bar", "baz") 2266 return tup 2267 2268 def test_namedtuple_kwarg_construct(self): 2269 class MyCoolNamedTuple(NamedTuple): 2270 a: int 2271 b: float 2272 c: List[int] 2273 2274 @torch.jit.script 2275 def foo(): 2276 tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9) 2277 return tup 2278 2279 tup = foo() 2280 self.assertEqual(tup.a, 9) 2281 
self.assertEqual(tup.b, 3.5) 2282 self.assertEqual(tup.c, [1, 2, 3]) 2283 2284 @unittest.skipIf(True, "broken while these tests were not in CI") 2285 def test_namedtuple_serialization(self): 2286 class MyCoolNamedTuple(NamedTuple): 2287 a: int 2288 b: float 2289 c: List[int] 2290 2291 class MyMod(torch.jit.ScriptModule): 2292 @torch.jit.script_method 2293 def forward(self): 2294 return MyCoolNamedTuple(3, 3.5, [3, 4, 5]) 2295 2296 mm = MyMod() 2297 mm.save("foo.zip") 2298 torch.testing._internal.jit_utils.clear_class_registry() 2299 loaded = torch.jit.load("foo.zip") 2300 2301 out = mm() 2302 out_loaded = loaded() 2303 2304 for name in ["a", "b", "c"]: 2305 self.assertEqual(getattr(out_loaded, name), getattr(out, name)) 2306 2307 def test_namedtuple_inside_forwardref(self): 2308 class FeatureVector(NamedTuple): 2309 float_features: "float" 2310 sequence_features: "List[float]" 2311 time_since_first: "float" 2312 2313 @torch.jit.script 2314 def foo(x) -> float: 2315 fv = FeatureVector(3.0, [3.0], 3.0) 2316 rv = fv.float_features 2317 for val in fv.sequence_features: 2318 rv += val 2319 rv *= fv.time_since_first 2320 return rv 2321 2322 self.assertEqual(foo(torch.rand(3, 4)), 18.0) 2323 2324 def test_namedtuple_input_forwardref(self): 2325 class MyNamedTuple(NamedTuple): 2326 a: "int" 2327 b: "float" 2328 c: "torch.Tensor" 2329 2330 make_global(MyNamedTuple) 2331 2332 nt = MyNamedTuple(4, 2.5, torch.rand((2, 2))) 2333 2334 def fn(obj: MyNamedTuple): 2335 return ((obj.c + obj.b) ** obj.a).sin() 2336 2337 expected = fn(nt) 2338 fn_s = torch.jit.script(fn) 2339 actual = fn_s(nt) 2340 self.assertEqual(expected, actual) 2341 2342 # see #95858 2343 @unittest.expectedFailure 2344 def test_namedtuple_resolution_forwardref(self): 2345 class TheType(NamedTuple): 2346 t: "int" 2347 2348 class MyModule(types.ModuleType): 2349 def __init__(self) -> None: 2350 super().__init__("MyModule") 2351 2352 def __getattr__(self, attr): 2353 return TheType 2354 2355 some_module = MyModule() 
2356 2357 def fn() -> some_module.Type: 2358 return some_module.Type(1) 2359 2360 self.checkScript(fn, []) 2361 2362 2363class TestScriptDict(JitTestCase): 2364 """ 2365 This class contains a suite of tests for torch.jit.script, a 2366 function that returns a dictionary-like object that has reference 2367 semantics across the Python/TorchScript boundary. That is, 2368 it can be passed to a TorchScript function that mutates it 2369 and those modifications are visible in the scope of the Python 2370 caller of said TorchScript function. 2371 2372 The vast majority of tests are for making sure that objects returned 2373 by torch.jit.script behave like dictionaries do so that they are fungible 2374 in almost all cirumstances with regular dictionaries. 2375 """ 2376 2377 def _script_dict_add(self, d: torch._C.ScriptDict, k: int, v: int): 2378 """ 2379 This is a helper function that inserts the pair (k, v) into the 2380 dictionary d in TorchScript. It is used for testing reference 2381 semantics. 2382 """ 2383 2384 @torch.jit.script 2385 def dict_add(d: Dict[int, int], k: int, v: int): 2386 d[k] = v 2387 2388 dict_add(d, k, v) 2389 2390 def _compare_eager_and_script(self, fn, input_dict, script_input_dict=None): 2391 """ 2392 This is a helper function that facilitates comparing behaviour between 2393 Python dictionaries and "scripted" dictionaries. 2394 2395 Args: 2396 fn: The function to test and compare the behaviour of. 2397 input_dict: The input dictionary to use for the test (passed to fn). 2398 script_input_dict: The scripted input dictionary to use for the tests. 2399 If None, input_dict is scripted with torch.jit.script 2400 and used instead. 2401 """ 2402 # Create ScriptDict version of input_dict if needed. 2403 script_input_dict = script_input_dict or torch.jit.script(input_dict) 2404 2405 # Run fn with both input_dict and scripted_dict. 
2406 eager_raised, script_raised = False, False 2407 2408 try: 2409 eager_out = fn(input_dict) 2410 except Exception as e: 2411 eager_exception = e 2412 eager_raised = True 2413 2414 try: 2415 script_out = fn(script_input_dict) 2416 except Exception as e: 2417 script_exception = e 2418 script_raised = True 2419 2420 # Check that both calls raised or none of them raised. 2421 self.assertEqual(eager_raised, script_raised) 2422 2423 if eager_raised: 2424 # If fn raised an exception, it should be the same between 2425 # regular and scripted dictionaries. 2426 self.assertEqual(type(eager_exception), type(script_exception)) 2427 else: 2428 # Otherwise, make sure the outputs match and the dictionaries 2429 # match (the latter may not be the same as the output). 2430 self.assertEqual(eager_out, script_out) 2431 self.assertEqual(input_dict, script_input_dict) 2432 2433 def test_repr(self): 2434 """ 2435 Test the __repr__ method. 2436 """ 2437 self._compare_eager_and_script(lambda d: repr(d), {1: 2}) 2438 2439 def test_bool(self): 2440 """ 2441 Test the __bool__ method. This should return True 2442 if the dictionary is non-empty and False otherwise. 2443 """ 2444 self._compare_eager_and_script(lambda d: bool(d), {1: 2}) 2445 self._compare_eager_and_script(lambda d: bool(d), {}) 2446 2447 def test_iter(self): 2448 """ 2449 Test iteration over a dictionary's keys. 2450 """ 2451 2452 def sum_keys(input_dict): 2453 s = 0 2454 for k in input_dict: 2455 s += k 2456 2457 return s 2458 2459 self._compare_eager_and_script(sum_keys, {1: 2, 3: 4}) 2460 2461 def test_items(self): 2462 """ 2463 Test .items(). 2464 """ 2465 2466 def sum_pair_product(input_dict): 2467 s = 0 2468 for k, v in input_dict.items(): 2469 s += k * v 2470 2471 return s 2472 2473 self._compare_eager_and_script(sum_pair_product, {1: 2, 3: 4}) 2474 2475 def test_getitem(self): 2476 """ 2477 Test accessing dictionary values using the [] operator. 
2478 """ 2479 data = {1: 2, 3: 4} 2480 self._compare_eager_and_script(lambda d: d[1], data) 2481 self._compare_eager_and_script(lambda d: d[4], data) 2482 self._compare_eager_and_script(lambda d: d[2], data) 2483 self._compare_eager_and_script(lambda d: d["key"], data) 2484 2485 def test_setitem(self): 2486 """ 2487 Test setting dictionary values using the [] operator. 2488 """ 2489 data = {1: 2, 3: 4} 2490 2491 def fn(input_dict): 2492 input_dict[1] = 10 2493 input_dict[3] = 11 2494 2495 self._compare_eager_and_script(fn, data) 2496 2497 # Check that using improperly typed keys and values 2498 # throws TypeError. 2499 # _compare_eager_and_script cannot be used here since 2500 # the following uses of __setitem__ are valid in 2501 # Python. 2502 script_data = torch.jit.script(data) 2503 2504 with self.assertRaises(TypeError): 2505 script_data["str"] = 3 2506 2507 with self.assertRaises(TypeError): 2508 script_data[3] = "str" 2509 2510 def test_contains(self): 2511 """ 2512 Test membership checks (x in y, x not in y). 2513 """ 2514 data = {1: 2, 3: 4} 2515 2516 def fn(input_dict): 2517 return ( 2518 1 in input_dict, 2519 2 not in input_dict, 2520 3 in input_dict, 2521 4 not in input_dict, 2522 ) 2523 2524 self._compare_eager_and_script(fn, data) 2525 2526 # Check that using an improperly typed key 2527 # throws KeyError. 2528 script_data = torch.jit.script(data) 2529 2530 with self.assertRaises(KeyError): 2531 a = "str" in script_data 2532 2533 def test_delitem(self): 2534 """ 2535 Test deletion. 2536 """ 2537 data = {1: 2, 3: 4} 2538 2539 def del_fn(input_dict): 2540 del input_dict[1] 2541 2542 def del_fn_raises(input_dict): 2543 del input_dict[10] 2544 2545 self._compare_eager_and_script(del_fn, data) 2546 self._compare_eager_and_script(del_fn_raises, data) 2547 2548 # Check that using an improperly typed key 2549 # throws TypeError. 
2550 script_data = torch.jit.script(data) 2551 2552 with self.assertRaises(TypeError): 2553 del script_data["str"] 2554 2555 def test_len(self): 2556 """ 2557 Test len() builtin function. 2558 """ 2559 self._compare_eager_and_script(lambda d: len(d), {1: 2}) 2560 self._compare_eager_and_script(lambda d: len(d), {}) 2561 2562 @unittest.skip( 2563 "Cannot pass until all dicts returned from TorchScript are ScriptDicts" 2564 ) 2565 def test_nested(self): 2566 """ 2567 Test that reference semantics are honoured when the ScriptDict that is 2568 mutated using TorchScript is inside another. 2569 """ 2570 nested = torch.jit.script( 2571 {1: {1: 2}, 2: {3: 4}}, type_hint=Dict[int, Dict[int, int]] 2572 ) 2573 2574 one = nested[1] 2575 two = nested[2] 2576 2577 self._script_dict_add(one, 9, 10) 2578 self._script_dict_add(two, 11, 12) 2579 2580 # The mutation should be visible in the original dictionary, nested. 2581 self.assertEqual(len(one), 2) 2582 self.assertEqual(len(two), 2) 2583 self.assertEqual(len(nested[1]), 2) 2584 self.assertEqual(len(nested[2]), 2) 2585 2586 def test_reference_semantics(self): 2587 """ 2588 Test that reference semantics are honoured; that modifications made 2589 to a ScriptDict in TorchScript are visible in Python. 2590 """ 2591 data = torch.jit.script({1: 2}) 2592 self._script_dict_add(data, 3, 4) 2593 2594 # The mutation should be visible in the original dictionary. 2595 self.assertEqual(len(data), 2) 2596 self.assertTrue(3 in data) 2597 self.assertEqual(data[3], 4) 2598 2599 2600class TestScriptList(JitTestCase): 2601 """ 2602 This class contains a suite of tests for torch._C.ScriptList, a 2603 function that returns a list-like object that has reference 2604 semantics across the Python/TorchScript boundary. That is, 2605 it can be passed to a TorchScript function that mutates it 2606 and those modifications are visible in the scope of the Python 2607 caller of said TorchScript function. 
2608 2609 The vast majority of tests are for making sure that instances of 2610 torch._C.ScriptList behave like lists do so that they are fungible 2611 in almost all cirumstances with regular list. 2612 """ 2613 2614 def _script_list_add(self, l: torch._C.ScriptList, e: int): 2615 """ 2616 This is a helper function that inserts the element e into the 2617 list l in TorchScript. It is used for testing reference 2618 semantics. 2619 """ 2620 2621 @torch.jit.script 2622 def list_add(l: List[int], e: int): 2623 l.append(e) 2624 2625 list_add(l, e) 2626 2627 def _compare_eager_and_script(self, fn, input_list, script_input_list=None): 2628 """ 2629 This is a helper function that facilitates comparing behaviour between 2630 Python lists and "scripted" lists. 2631 Args: 2632 fn: The function to test and compare the behaviour of. 2633 input_list: The input list to use for the test (passed to fn). 2634 script_input_list: The scripted input list to use for the tests. 2635 If None, input_list is scripted with torch.jit.script 2636 and used instead. 2637 """ 2638 # Create ScriptDict version of input_list if needed. 2639 script_input_list = script_input_list or torch.jit.script(input_list) 2640 2641 # Run fn with both input_list and scripted_dict. 2642 eager_raised, script_raised = False, False 2643 2644 try: 2645 eager_out = fn(input_list) 2646 except Exception as e: 2647 eager_exception = e 2648 eager_raised = True 2649 2650 try: 2651 script_out = fn(script_input_list) 2652 except Exception as e: 2653 script_exception = e 2654 script_raised = True 2655 2656 # Check that both calls raised or none of them raised. 2657 self.assertEqual(eager_raised, script_raised) 2658 2659 if eager_raised: 2660 # If fn raised an exception, it should be the same between 2661 # regular and scripted lists. 
2662 self.assertEqual(type(eager_exception), type(script_exception)) 2663 else: 2664 # Otherwise, make sure the outputs match and the lists 2665 # match (the latter may not be the same as the output). 2666 self.assertEqual(eager_out, script_out) 2667 self.assertEqual(input_list, script_input_list) 2668 2669 def test_repr(self): 2670 """ 2671 Test the __repr__ method. 2672 """ 2673 self._compare_eager_and_script(lambda l: repr(l), [1]) 2674 2675 def test_bool(self): 2676 """ 2677 Test the __bool__ method. This should return True 2678 if the list is non-empty and False otherwise. 2679 """ 2680 self._compare_eager_and_script(lambda l: bool(l), [1]) 2681 self._compare_eager_and_script(lambda l: bool(l), []) 2682 2683 def test_iter(self): 2684 """ 2685 Test iteration over a list's elements. 2686 """ 2687 2688 def sum_elements(input_list): 2689 s = 0 2690 for k in input_list: 2691 s += k 2692 2693 return s 2694 2695 self._compare_eager_and_script(sum_elements, [1, 2, 3, 4]) 2696 2697 def test_getitem(self): 2698 """ 2699 Test accessing list elements using the [] operator. 2700 """ 2701 data = [1, 2, 3, 4] 2702 2703 # Test regular indexing. 2704 self._compare_eager_and_script(lambda l: l[1], data) 2705 self._compare_eager_and_script(lambda l: l[3], data) 2706 self._compare_eager_and_script(lambda l: l[-1], data) 2707 2708 # Test slicing. 2709 self._compare_eager_and_script(lambda l: l[1:3], data) 2710 self._compare_eager_and_script(lambda l: l[:], data) 2711 self._compare_eager_and_script(lambda l: l[1:], data) 2712 self._compare_eager_and_script(lambda l: l[:2], data) 2713 self._compare_eager_and_script(lambda l: l[-1], data) 2714 self._compare_eager_and_script(lambda l: l[-1::-1], data) 2715 2716 # Test errors. 
2717 self._compare_eager_and_script(lambda l: l[5], data) 2718 self._compare_eager_and_script(lambda l: l[-7], data) 2719 self._compare_eager_and_script(lambda l: l["key"], data) 2720 2721 def test_setitem(self): 2722 """ 2723 Test setting list elements using the [] operator. 2724 """ 2725 data = [1, 2, 3, 4] 2726 2727 # Test regular assignment. 2728 def setitem(input_list): 2729 input_list[1] = 10 2730 input_list[3] = 11 2731 input_list[-1] = 12 2732 2733 self._compare_eager_and_script(setitem, data.copy()) 2734 2735 # Test slice assignment. 2736 # TODO: Something like input_list[:1] = [1, 2, 3, 4, 5] 2737 # is allowed in Python, but pybind11/stl_bind.h does not 2738 # allow it. Should we? 2739 def setitem_slice(input_list): 2740 input_list[:4:2] = [10, 11] 2741 input_list[-2:] = [15, 16] 2742 2743 self._compare_eager_and_script(setitem_slice, data) 2744 2745 # Test errors. 2746 def out_of_range(input_list): 2747 input_list[11] = 3 2748 2749 def out_of_range_negative(input_list): 2750 input_list[-11] = 3 2751 2752 def wrong_index_type(input_list): 2753 input_list["str"] = 3 2754 2755 self._compare_eager_and_script(out_of_range, data) 2756 self._compare_eager_and_script(out_of_range_negative, data) 2757 self._compare_eager_and_script(wrong_index_type, data) 2758 2759 # Check that using value of an incorrect type throws TypeError. 2760 # _compare_eager_and_script cannot be used here since 2761 # the following use of __setitem__ is valid in 2762 # Python. 2763 script_data = torch.jit.script(data) 2764 2765 with self.assertRaises(TypeError): 2766 script_data[0] = "str" 2767 2768 def test_contains(self): 2769 """ 2770 Test membership checks (x in y, x not in y). 2771 """ 2772 data = [1, 2, 3, 4] 2773 2774 def fn(input_list): 2775 return ( 2776 1 in input_list, 2777 2 not in input_list, 2778 3 in input_list, 2779 4 not in input_list, 2780 ) 2781 2782 self._compare_eager_and_script(fn, data) 2783 2784 # Check that using a value of an incorrect type throws a TypeError. 
2785 script_data = torch.jit.script(data) 2786 2787 with self.assertRaises(TypeError): 2788 a = "str" in script_data 2789 2790 def test_delitem(self): 2791 """ 2792 Test deletion. 2793 """ 2794 data = [1, 2, 3, 4] 2795 2796 def del_fn(input_list): 2797 del input_list[1] 2798 2799 def del_fn_out_of_range(input_list): 2800 del input_list[10] 2801 2802 def del_fn_wrong_type(input_list): 2803 del input_list["str"] 2804 2805 self._compare_eager_and_script(del_fn, data.copy()) 2806 self._compare_eager_and_script(del_fn_out_of_range, data) 2807 self._compare_eager_and_script(del_fn_wrong_type, data) 2808 2809 def test_len(self): 2810 """ 2811 Test len() builtin function. 2812 """ 2813 self._compare_eager_and_script(lambda l: len(l), [1, 2, 3, 4]) 2814 self._compare_eager_and_script(lambda l: len(l), []) 2815 2816 def test_count(self): 2817 """ 2818 Test count method. 2819 """ 2820 self._compare_eager_and_script(lambda l: l.count(3), [1, 2, 3, 3]) 2821 2822 # Check that using a value of an incorrect type throws TypeError. 2823 script_data = torch.jit.script([1]) 2824 2825 with self.assertRaises(TypeError): 2826 script_data.count("str") 2827 2828 def test_remove(self): 2829 """ 2830 Test remove method. 2831 """ 2832 self._compare_eager_and_script(lambda l: l.remove(1), [1, 2, 3]) 2833 self._compare_eager_and_script(lambda l: l.remove(10), [1, 2, 3]) 2834 2835 # Check that using a value of an incorrect type throws TypeError. 2836 script_data = torch.jit.script([1]) 2837 2838 with self.assertRaises(TypeError): 2839 script_data.remove("str") 2840 2841 def test_append(self): 2842 """ 2843 Test append method. 2844 """ 2845 self._compare_eager_and_script(lambda l: l.append(1), [4, 3, 2]) 2846 2847 # Check that using a value of an incorrect type throws TypeError. 
2848 script_data = torch.jit.script([1]) 2849 2850 with self.assertRaises(TypeError): 2851 script_data.append("str") 2852 2853 @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") 2854 def test_clear(self): 2855 """ 2856 Test clear. 2857 """ 2858 self._compare_eager_and_script(lambda l: l.clear(), [4, 3, 2]) 2859 2860 def test_extend(self): 2861 """ 2862 Test extend. 2863 """ 2864 2865 class Iterable: 2866 def __init__(self, limit: int): 2867 self.limit = limit 2868 self.value = 0 2869 2870 def __iter__(self): 2871 return self 2872 2873 def __next__(self): 2874 if self.value == limit: # noqa: F821 2875 raise StopIteration 2876 2877 ret = self.value 2878 self.value += 1 2879 return ret 2880 2881 data = [1, 2, 3] 2882 2883 def extend_list(input_list): 2884 input_list.extend([4, 5, 6]) 2885 2886 def extend_dict(input_list): 2887 input_list.extend({4: 10, 5: 11, 6: 12}) 2888 2889 def extend_iterable(input_list): 2890 input_list.extend(Iterable(3)) 2891 2892 self._compare_eager_and_script(extend_list, data.copy()) 2893 self._compare_eager_and_script(extend_dict, data.copy()) 2894 self._compare_eager_and_script(extend_iterable, data) 2895 2896 # Check that using a value of an incorrect type throws TypeError. 2897 script_data = torch.jit.script([1]) 2898 2899 with self.assertRaises(TypeError): 2900 script_data.extend(["a"]) 2901 2902 with self.assertRaises(TypeError): 2903 script_data.extend({"a": 1}) 2904 2905 def test_insert(self): 2906 """ 2907 Test insert. 2908 """ 2909 data = [1, 2, 4] 2910 2911 self._compare_eager_and_script(lambda l: l.insert(3, 3), data.copy()) 2912 self._compare_eager_and_script(lambda l: l.insert(0, 3), data.copy()) 2913 self._compare_eager_and_script(lambda l: l.insert(-2, 3), data) 2914 2915 # Check that using a value of an incorrect type throws TypeError. 
2916 script_data = torch.jit.script([1]) 2917 2918 with self.assertRaises(TypeError): 2919 script_data.insert((0, "str")) 2920 2921 def test_pop(self): 2922 """ 2923 Test pop. 2924 """ 2925 data = [1, 2, 3, 4, 5] 2926 2927 # Test normal cases. 2928 self._compare_eager_and_script(lambda l: l.pop(), data.copy()) 2929 self._compare_eager_and_script(lambda l: l.pop(2), data.copy()) 2930 self._compare_eager_and_script(lambda l: l.pop(-3), data.copy()) 2931 2932 # Test error cases. 2933 self._compare_eager_and_script(lambda l: l.pop(10), data) 2934 2935 @unittest.skip( 2936 "Cannot pass until all list returned from TorchScript are ScriptLists" 2937 ) 2938 def test_nested(self): 2939 """ 2940 Test that reference semantics are honoured when the ScriptList that is 2941 mutated using TorchScript is inside another. 2942 """ 2943 nested = torch.jit.script([[1], [2]], List[List[int]]) 2944 2945 one = nested[0] 2946 two = nested[1] 2947 2948 self._script_list_add(one, 3) 2949 self._script_list_add(two, 4) 2950 2951 # The mutation should be visible in the original list, nested. 2952 self.assertEqual(len(one), 2) 2953 self.assertEqual(len(two), 2) 2954 self.assertEqual(one[len(one) - 1], 3) 2955 self.assertEqual(two[len(one) - 1], 4) 2956 self.assertEqual(len(nested[0]), 2) 2957 self.assertEqual(len(nested[1]), 2) 2958 2959 def test_reference_semantics(self): 2960 """ 2961 Test that reference semantics are honoured; that modifications made 2962 to a ScriptList in TorchScript are visible in Python. 
2963 """ 2964 l = torch.jit.script([1, 2]) 2965 self._script_list_add(l, 3) 2966 2967 self.assertEqual(len(l), 3) 2968 self.assertTrue(3 in l) 2969 self.assertEqual(l[2], 3) 2970 2971 def test_defaultdict(self): 2972 def get_dict(): 2973 test_dict = defaultdict(list) 2974 return test_dict 2975 2976 class Test(torch.nn.Module): 2977 segments_groupby_col: Dict[str, List[str]] 2978 2979 def __init__(self) -> None: 2980 super().__init__() 2981 self.segments_groupby_col = get_dict() 2982 self.col1 = "a" 2983 self.col2 = "b" 2984 2985 def forward(self): 2986 if self.col1 in self.segments_groupby_col.keys(): 2987 return 1 2988 else: 2989 return 2 2990 2991 test = Test() 2992 test_script = torch.jit.script(test) 2993 test_script.segments_groupby_col 2994 2995 # Smoketest for flakiness. Takes around 2s. 2996 for i in range(300): 2997 test = Test() 2998 test_script = torch.jit.script(test) 2999