1# Owner(s): ["oncall: jit"] 2 3import io 4import os 5import sys 6import unittest 7from enum import Enum 8from textwrap import dedent 9from typing import Dict, List, Optional, Tuple, Union 10 11import torch 12from torch.testing import FileCheck 13 14 15# Make the helper files in test/ importable 16pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 17sys.path.append(pytorch_test_dir) 18from torch.testing._internal.jit_utils import JitTestCase, make_global 19 20 21if __name__ == "__main__": 22 raise RuntimeError( 23 "This test file is not meant to be run directly, use:\n\n" 24 "\tpython test/test_jit.py TESTNAME\n\n" 25 "instead." 26 ) 27 28 29@unittest.skipIf(sys.version_info < (3, 10), "Requires Python 3.10") 30class TestUnion(JitTestCase): 31 """ 32 This class tests the functionality of `Union`. 33 34 Note: It's important to be able to refine the type of a `Union` to 35 one of its internal types. Currently, there are differences in the 36 way Python expects `isinstance` checks and the way TorchScript 37 expects `isinstance` checks. This means that we can't use 38 `checkScript` in our test cases because either the eager mode or the 39 script mode wouldn't run! So, some test cases have separate but 40 equivalent functions to emulate `checkScript`. 41 """ 42 43 def test_check_union_annotation(self): 44 def test_func(a: int | float, b: Optional[int]): 45 return 0 46 47 scripted_func = torch.jit.script(test_func) 48 graph_rep = str(scripted_func.graph) 49 code_rep = str(scripted_func.code) 50 # TS graph IR for Union should be annotated as Union() 51 FileCheck().check("Union(").check("int?").run(graph_rep) 52 # Serialized code for Union should be annotated as Union[] 53 FileCheck().check("Union[").check("Optional[int]").run(code_rep) 54 self.checkScript(test_func, (5, 6)) 55 # this shouldn't error out 56 torch._C.parse_ir(str(scripted_func.graph)) 57 58 def test_union_with_scalar_values(self): 59 def fn(x: int | float) -> str: 60 return "foo" 61 62 self.checkScript(fn, (1,)) 63 self.checkScript(fn, (1.0,)) 64 65 scripted = torch.jit.script(fn) 66 67 with self.assertRaisesRegex( 68 RuntimeError, 69 "Expected a member of" 70 r" Union\[float, int\] but " 71 "instead found type str", 72 ): 73 scripted("1") 74 75 def test_union_with_collections(self): 76 def fn(x: Dict[str, int] | List[int]) -> str: 77 return "foo" 78 79 self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},)) 80 self.checkScript(fn, ([1, 2, 3],)) 81 82 scripted = torch.jit.script(fn) 83 84 with self.assertRaisesRegex( 85 RuntimeError, 86 "Expected a member of" 87 r" Union\[List\[int\], Dict\[str, " 88 r"int\]\] but instead found type " 89 r"Dict\[str, str\]", 90 ): 91 scripted({"foo": "bar", "baz": "qux"}) 92 93 with self.assertRaisesRegex( 94 RuntimeError, 95 "Expected a member of" 96 r" Union\[List\[int\], Dict\[str, " 97 r"int\]\] but instead found type " 98 r"List\[str\]", 99 ): 100 scripted(["foo", "bar", "baz"]) 101 102 with self.assertRaisesRegex( 103 RuntimeError, 104 "Expected a member of" 105 r" Union\[List\[int\], Dict\[str, " 106 r"int\]\] but instead found type " 107 "str", 108 ): 109 scripted("1") 110 111 def test_union_with_enum(self): 112 class Color(Enum): 113 RED = 1 114 GREEN = 2 115 116 make_global(Color) 117 118 def fn(x: str | Color) -> str: 119 return "foo" 120 121 self.checkScript(fn, (Color.RED,)) 122 self.checkScript(fn, ("red",)) 123 124 scripted = torch.jit.script(fn) 125 126 with self.assertRaisesRegex( 127 RuntimeError, 128 "Expected a member of" 129 r" Union\[__torch__.jit.test_union_pep604." 130 r"Color, str\] but instead found " 131 "type int", 132 ): 133 scripted(1) 134 135 def test_union_in_class_constructor(self): 136 @torch.jit.script # noqa: B903 137 class A: # noqa: B903 138 def __init__(self, x: int | str) -> None: 139 self.x = x 140 141 def fn(x: str | int) -> A: 142 return A(x) 143 144 self.assertEqual(fn("foo").x, "foo") 145 self.assertEqual(fn(1).x, 1) 146 147 scripted = torch.jit.script(fn) 148 149 with self.assertRaisesRegex( 150 RuntimeError, 151 "Expected a member of" 152 r" Union\[int, str\] but instead " 153 r"found type List\[str\]", 154 ): 155 scripted(["foo", "bar", "baz"]) 156 157 def test_union_return_type(self): 158 def fn(x: int) -> int | str: 159 return "foo" 160 161 self.checkScript(fn, (1,)) 162 163 def test_union_as_annotation(self): 164 def fn() -> int | str: 165 x: int | str = "foo" 166 return x 167 168 self.checkScript(fn, ()) 169 170 def test_union_as_annotation_in_typed_container(self): 171 def fn() -> None: 172 l: List[int | str] = [] 173 u1: int | str = "foo" 174 u2: int | str = 1 175 l.append(u1) 176 l.append(u2) 177 178 self.checkScript(fn, ()) 179 180 def test_union_as_annotation_py2(self): 181 def fn(): 182 # type: () -> int | str 183 x: int | str = "foo" 184 return x 185 186 self.checkScript(fn, ()) 187 188 def test_union_as_internal_tuple_type(self): 189 def fn(): 190 t: Tuple[int | str, int | str] = (1, "foo") 191 return t 192 193 self.checkScript(fn, ()) 194 195 def test_union_variable_can_be_reassigned(self): 196 @torch.jit.script 197 def aux1(i: int): 198 return int(i**2) 199 200 @torch.jit.script 201 def aux2(s: str): 202 return s + s 203 204 def fn() -> int | str: 205 x: int | str = "foo" 206 i: int = 1 207 x = i 208 y: int = aux1(x) 209 z: str = aux2(str(y)) 210 x = z 211 return x 212 213 self.checkScript(fn, ()) 214 215 def test_union_does_not_replace_existing_annotated_type(self): 216 def fn(): 217 x: List[int] = [1, 2, 3] 218 x.append("foo") 219 return x 220 221 with self.assertRaisesRegex(RuntimeError, "Could not match type str"): 222 scripted = torch.jit.script(fn) 223 scripted() 224 225 def test_union_does_not_replace_existing_annotated_type_union(self): 226 def fn(): 227 x: List[int | str] = [1, "foo", 3] 228 x.append(2.0) 229 return x 230 231 with self.assertRaisesRegex(RuntimeError, "Could not match type float"): 232 scripted = torch.jit.script(fn) 233 scripted() 234 235 def test_union_does_not_replace_existing_annotated_type_empty_container(self): 236 def fn(): 237 x: List[int] = [] 238 x.append("foo") 239 return x 240 241 with self.assertRaisesRegex(RuntimeError, "Could not match type str"): 242 scripted = torch.jit.script(fn) 243 scripted() 244 245 def test_unions_of_unions_are_flattened(self): 246 @torch.jit.script 247 def fn(x: (int | str) | float) -> str: 248 return "foo" 249 250 s = fn.graph 251 252 FileCheck().check("x : Union(float, int, str)").run(s) 253 254 def test_unions_of_a_single_argument_vanish(self): 255 @torch.jit.script 256 def fn(x: Union[int]) -> str: 257 return "foo" 258 259 s = fn.graph 260 261 FileCheck().check("x : int").run(s) 262 263 def test_union_redundant_arguments_are_skipped(self): 264 @torch.jit.script 265 def fn(x: int | str | int) -> str: 266 return "foo" 267 268 s = fn.graph 269 270 FileCheck().check("x : Union(int, str)").run(s) 271 272 def test_union_redundant_arguments_are_skipped_optional(self): 273 @torch.jit.script 274 def fn(x: int | Optional[float] | Optional[int]) -> str: 275 return "foo" 276 277 s = fn.graph 278 279 FileCheck().check("x : Union(float, int, NoneType)").run(s) 280 281 def test_union_redundant_arguments_are_skipped_subtyping(self): 282 @torch.jit.script 283 def fn(x: str | Tuple[Optional[int], int] | Tuple[int, int]) -> str: 284 return "foo" 285 286 s = fn.graph 287 288 FileCheck().check("x : Union((int?, int), str)").run(s) 289 290 def test_union_redundant_arguments_are_skipped_container(self): 291 @torch.jit.script 292 def fn(x: List[str] | List[float] | List[str]) -> str: 293 return "foo" 294 295 s = fn.graph 296 297 FileCheck().check("x : Union(float[], str[])").run(s) 298 299 def test_union_argument_order_is_ignored(self): 300 @torch.jit.script 301 def fn1(x: int | str) -> str: 302 return "foo" 303 304 @torch.jit.script 305 def fn2(x: str | int) -> str: 306 return "foo" 307 308 for s in (fn1.graph, fn2.graph): 309 FileCheck().check("x : Union(int, str)").run(s) 310 311 def test_union_argument_order_is_ignored_container(self): 312 @torch.jit.script 313 def fn1(x: List[str] | List[int]) -> str: 314 return "foo" 315 316 @torch.jit.script 317 def fn2(x: List[int] | List[str]) -> str: 318 return "foo" 319 320 for s in (fn1.graph, fn2.graph): 321 FileCheck().check("x : Union(int[], str[])").run(s) 322 323 def test_union_T_None_is_equivalent_to_optional_T(self): 324 @torch.jit.script 325 def inner(x: int | None) -> int: 326 if x is not None: 327 return x 328 else: 329 return 5 330 331 @torch.jit.script 332 def fn1() -> int: 333 a: Optional[int] = 5 334 b: Optional[int] = None 335 a_ = inner(a) 336 b_ = inner(b) 337 return a_ + b_ 338 339 self.assertEqual(fn1(), 10) 340 341 @torch.jit.script 342 def inner2(x: Optional[int]) -> int: 343 if x is not None: 344 return x 345 else: 346 return 5 347 348 @torch.jit.script 349 def fn2() -> int: 350 a: int | None = 5 351 b: int | None = None 352 a_ = inner(a) 353 b_ = inner(b) 354 return a_ + b_ 355 356 self.assertEqual(fn2(), 10) 357 358 @unittest.expectedFailure 359 def test_union_optional_of_union_return(self): 360 @torch.jit.script 361 def fn() -> None | str | int: 362 y: Optional[int | str] = "foo" 363 return y 364 365 @unittest.expectedFailure 366 def test_union_optional_of_union_is_flattened(self): 367 @torch.jit.script 368 def fn(flag: int) -> str | int | None: 369 y: int | str | None = "foo" 370 if flag == 0: 371 x: Optional[int | str] = y 372 elif flag == 1: 373 x: Optional[int | str] = 1 374 else: 375 x: Optional[int | str] = None 376 return x 377 378 # Can't use `checkScript` because it will flag the fact that 379 # the original code has `Optional[Union[int, str]]` but the 380 # saved/loaded code has `Union[int, NoneType, str]` (even 381 # though this is exactly what we want) 382 self.assertEqual(fn(0), "foo") 383 self.assertEqual(fn(1), 1) 384 self.assertEqual(fn(2), None) 385 386 buffer = io.BytesIO() 387 torch.jit.save(fn, buffer) 388 buffer = io.BytesIO(buffer.getvalue()) 389 l = torch.jit.load(buffer) 390 391 s = l.code 392 393 FileCheck().check("Union[int, NoneType, str]").check( 394 "Union[int, NoneType, str]" 395 ).run(s) 396 397 def test_union_subclasses_larger_union(self): 398 def fn() -> int | str | torch.Tensor: 399 x: int | str = "foo" 400 return x 401 402 self.checkScript(fn, ()) 403 404 # TODO: We would like to eventually support this. The issue is being 405 # tracked at https://github.com/pytorch/pytorch/issues/58167 406 def test_union_as_dict_key(self): 407 def fn(): 408 x: Dict[int | str, str] = {} 409 x["foo"] = "bar" 410 x[1] = 2 411 return x[1] 412 413 with self.assertRaisesRegex( 414 RuntimeError, 415 "only int, float, " 416 "complex, Tensor, device and string keys " 417 "are supported", 418 ): 419 torch.jit.script(fn) 420 421 def test_union_as_dict_value(self): 422 def fn(): 423 x: Dict[str, int | str] = {} 424 x["foo"] = "bar" 425 x["baz"] = 2 426 return x["baz"] 427 428 self.checkScript(fn, ()) 429 430 def test_union_module_with_union_instance_variable(self): 431 class M(torch.nn.Module): 432 x: int | str 433 434 def __init__(self, x: int | str): 435 super().__init__() 436 self.x: int | str = x 437 438 def forward(self, y: int | str): 439 self.x = y 440 return self.x 441 442 self.checkModule( 443 M( 444 2, 445 ), 446 (1,), 447 ) 448 self.checkModule(M("bar"), ("foo",)) 449 450 def test_union_module_with_union_class_variable(self): 451 class M(torch.nn.Module): 452 x: int | str = "foo" 453 454 def __init__(self, y: int): 455 super().__init__() 456 x = y 457 458 def forward(self, z: str): 459 x = z 460 return x 461 462 self.checkModule(M(1), ("foo",)) 463 464 def test_union_type_refinement(self): 465 def fn(x: int | str) -> str: 466 if isinstance(x, str): 467 z = x + "bar" 468 return x 469 else: 470 return "baz" 471 472 self.checkScript(fn, ("foo",)) 473 self.checkScript(fn, (1,)) 474 475 def test_union_type_refinement_union_rhs(self): 476 def fn(x: int) -> str: 477 if torch.jit.isinstance(x, int | str): 478 return "bar" 479 else: 480 return "baz" 481 482 self.checkScript(fn, (1,)) 483 484 def test_union_type_refinement_tuple_rhs(self): 485 def fn(x: int | float | List[str]) -> str: 486 if isinstance(x, (int, float)): 487 if isinstance(x, int): 488 return str(x) 489 else: 490 return "foo" 491 else: 492 if len(x): 493 return x[0] 494 else: 495 return "bar" 496 497 self.checkScript(fn, (1,)) 498 self.checkScript(fn, (1.0,)) 499 self.checkScript(fn, (["a", "b", "c"],)) 500 501 def test_union_type_refinement_tuple_rhs_noncontained_type(self): 502 def fn(x: int | List[str]) -> str: 503 if isinstance(x, (int, float)): 504 y = x + x 505 return str(y) 506 else: 507 if len(x): 508 return x[0] 509 else: 510 return "bar" 511 512 self.checkScript(fn, (1,)) 513 self.checkScript(fn, (["a", "b", "c"],)) 514 515 def test_union_type_refinement_tuple_rhs_union(self): 516 @torch.jit.script 517 def fn(x: int) -> str: 518 if torch.jit.isinstance(x, (int | str, float)): 519 y = x + x 520 return str(y) 521 else: 522 return "foo" 523 524 # TODO: There's currently an unrelated bug in 525 # `torch.jit.isinstance` that makes it fail for tuple literals. 526 # Posted here: https://github.com/pytorch/pytorch/issues/60095 527 # Change `assertEqual` to `checkScript` when the bug is fixed 528 self.assertEqual(fn(1), "2") 529 530 def test_union_type_refinement_statically_false(self): 531 @torch.jit.script 532 def fn(x: int) -> str: 533 if torch.jit.isinstance(x, (str | float, List[str], str)): 534 z = x + "foo" 535 return z 536 else: 537 return "bar" 538 539 s = fn.graph 540 541 # Check that we don't have any branching statements 542 FileCheck().check_not("block0()").check_not("block1()").run(s) 543 544 def test_union_type_refinement_statically_true(self): 545 @torch.jit.script 546 def fn(x: List[int] | int) -> List[int] | int: 547 if not torch.jit.isinstance(x, (int, List[int])): 548 return x 549 else: 550 l = [1, 2, 3] 551 y: List[int] | int = l 552 return y 553 554 s = fn.graph 555 556 # Check that we don't have any branching statements 557 FileCheck().check_not("block0()").check_not("block1()").run(s) 558 559 def test_union_type_refinement_partial_static_refinement_tuple_rhs(self): 560 def fn(x: List[int] | int) -> int: 561 if torch.jit.isinstance(x, (int, float, str)): 562 # We should know that `x` is an `int` here 563 z = x + 1 564 return z 565 else: 566 return 100 567 568 self.checkScript(fn, ([1, 2, 3],)) 569 self.checkScript(fn, (1,)) 570 571 def test_union_type_refinement_partial_static_refinement_union_rhs(self): 572 def fn(x: List[int] | int) -> int: 573 if torch.jit.isinstance(x, int | float | str): 574 # We should know that `x` is an `int` here 575 z = x + 1 576 return z 577 else: 578 return 100 579 580 self.checkScript(fn, ([1, 2, 3],)) 581 self.checkScript(fn, (1,)) 582 583 def test_union_type_refinement_internal_declaration(self): 584 def fn(flag: bool) -> str: 585 x: int | str | None = None 586 if flag: 587 y = "foo" 588 else: 589 y = 1 590 if isinstance(x, str): 591 return x 592 else: 593 return "bar" 594 595 self.checkScript(fn, (True,)) 596 self.checkScript(fn, (False,)) 597 598 def test_union_branching_with_union_return_and_homogenous_types(self): 599 def fn(x: int) -> int | str: 600 if x % 2: 601 return "foo" 602 else: 603 return "bar" 604 605 self.checkScript(fn, (1,)) 606 self.checkScript(fn, (8,)) 607 608 def test_union_branching_does_not_autoinfer_undeclared_union(self): 609 def fn(x: int) -> str: 610 if x % 2: 611 y = "foo" 612 else: 613 y = x 614 if isinstance(y, str): 615 return y 616 else: 617 return "bar" 618 619 with self.assertRaisesRegex( 620 RuntimeError, 621 "y is set to type str" 622 " in the true branch and type int " 623 "in the false branch", 624 ): 625 torch.jit.script(fn) 626 627 def test_union_branching_does_not_widen_existing_inferred_type(self): 628 def fn(x: int) -> str: 629 y = "foo" 630 if x % 2: 631 y = "bar" 632 else: 633 y = x 634 if isinstance(y, str): 635 return y 636 else: 637 return "baz" 638 639 with self.assertRaisesRegex( 640 RuntimeError, 641 "previously had type " 642 "str but is now being assigned to a" 643 " value of type int", 644 ): 645 torch.jit.script(fn) 646 647 def test_union_schema_matching_on_internal_type(self): 648 def fn(x: List[int] | Dict[str, int]) -> int: 649 if torch.jit.isinstance(x, List[int]): 650 return x[0] 651 else: 652 return list(x.values())[0] 653 654 self.checkScript(fn, ([1, 2, 3],)) 655 self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},)) 656 657 def test_union_subtractive_refinement(self): 658 def fn(x: List[int] | int) -> int: 659 if not isinstance(x, int): 660 x.append(1) 661 return x[0] 662 else: 663 return x 664 665 self.checkScript(fn, (1,)) 666 self.checkScript(fn, ([1, 2, 3],)) 667 668 def test_union_subtractive_refinement_with_container(self): 669 def fn(x: List[int] | int) -> int: 670 if not torch.jit.isinstance(x, List[int]): 671 return x 672 else: 673 x.append(1) 674 return x[0] 675 676 self.checkScript(fn, (1,)) 677 self.checkScript(fn, ([1, 2, 3],)) 678 679 def test_union_memory_aliasing(self): 680 def fn(): 681 x: List[torch.Tensor] = [] 682 z: List[Optional[List[torch.Tensor]]] = [] 683 z.append(x) 684 x_alias = z[0] 685 if torch.jit.isinstance(x_alias, List[torch.Tensor]): 686 x_alias.append(torch.tensor(3)) 687 return x 688 689 self.checkScript(fn, ()) 690 691 def test_union_serialization_preserves_type_annotations(self): 692 # This function will fail after being torch.jit.save'd and 693 # torch.jit.load'd if the type annotations aren't preserved 694 # for Union during serialization. We need the `Union[str, int]` 695 # annotation to make sure that `y` is typed as a Union instead 696 # of as a str in one branch and an int in the other 697 def fn(x: int) -> str: 698 if x % 2: 699 y: str | int = "bar" 700 else: 701 y: str | int = x 702 if isinstance(y, str): 703 return y 704 else: 705 return "baz" 706 707 self.checkScript(fn, (1,)) 708 self.checkScript(fn, (8,)) 709 710 def _assert_passes(self, template: str, ann: str, lhs: str): 711 code = template.format(ann=ann, lhs=lhs) 712 self.checkScript(code, (), name="fn") 713 714 def _assert_raises(self, template: str, ann: str, lhs: str, msg: str): 715 code = template.format(ann=ann, lhs=lhs) 716 with self.assertRaisesRegex(RuntimeError, msg): 717 cu = torch.jit.CompilationUnit(code, _frames_up=1) 718 string_frontend = getattr(cu, "fn") # noqa: B009 719 720 def test_union_with_list_assignment(self): 721 template = dedent( 722 """ 723 def fn(): 724 x: {ann} = {lhs} 725 if torch.jit.isinstance(x, List[torch.Tensor]): 726 x.append(torch.tensor(3)) 727 return x 728 """ 729 ) 730 731 lhs = { 732 "list_literal_empty": "[]", 733 "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]", 734 "list_literal_of_str": '["foo", "bar", "baz"]', 735 "list_literal_of_mixed": "[torch.arange(5), 1]", 736 "list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]", 737 "list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]', 738 "list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]", 739 } 740 741 """ 742 List[str] | List[torch.Tensor] 743 """ 744 self._assert_raises( 745 template, 746 "List[str] | List[torch.Tensor]", 747 lhs["list_literal_empty"], 748 "there are multiple possible List type " 749 "candidates in the Union annotation", 750 ) 751 752 self._assert_passes( 753 template, "List[str] | List[torch.Tensor]", lhs["list_literal_of_tensor"] 754 ) 755 756 self._assert_passes( 757 template, "List[str] | List[torch.Tensor]", lhs["list_literal_of_str"] 758 ) 759 760 self._assert_raises( 761 template, 762 "List[str] | List[torch.Tensor]", 763 lhs["list_literal_of_mixed"], 764 "none of those types match the types of the" " given list elements", 765 ) 766 767 self._assert_passes( 768 template, 769 "List[str] | List[torch.Tensor]", 770 lhs["list_comprehension_of_tensor"], 771 ) 772 773 self._assert_passes( 774 template, "List[str] | List[torch.Tensor]", lhs["list_comprehension_of_str"] 775 ) 776 777 # TODO: Support mixed list comprehensions 778 self._assert_raises( 779 template, 780 "List[str] | List[torch.Tensor]", 781 lhs["list_comprehension_of_mixed"], 782 "Arguments for call are not valid", 783 ) 784 785 """ 786 int | torch.Tensor 787 """ 788 self._assert_raises( 789 template, 790 "int | torch.Tensor", 791 lhs["list_literal_empty"], 792 "Expected an Union type annotation with an " "inner List type", 793 ) 794 795 self._assert_raises( 796 template, 797 "int | torch.Tensor", 798 lhs["list_literal_of_tensor"], 799 "Expected an Union type annotation with an " "inner List type", 800 ) 801 802 self._assert_raises( 803 template, 804 "int | torch.Tensor", 805 lhs["list_comprehension_of_tensor"], 806 "Expected an Union type annotation with an " "inner List type", 807 ) 808 809 """ 810 List[torch.Tensor] | int 811 """ 812 self._assert_passes( 813 template, "List[torch.Tensor] | int", lhs["list_literal_empty"] 814 ) 815 816 self._assert_passes( 817 template, "List[torch.Tensor] | int", lhs["list_literal_of_tensor"] 818 ) 819 820 self._assert_raises( 821 template, 822 "List[torch.Tensor] | int", 823 lhs["list_literal_of_str"], 824 r"List type annotation `List\[Tensor\]` did " 825 "not match the types of the given list " 826 "elements", 827 ) 828 829 self._assert_raises( 830 template, 831 "List[torch.Tensor] | int", 832 lhs["list_literal_of_mixed"], 833 r"List type annotation `List\[Tensor\]` did " 834 "not match the types of the given list " 835 "elements", 836 ) 837 838 self._assert_passes( 839 template, "List[torch.Tensor] | int", lhs["list_comprehension_of_tensor"] 840 ) 841 842 self._assert_raises( 843 template, 844 "List[torch.Tensor] | int", 845 lhs["list_comprehension_of_str"], 846 r"List type annotation `List\[Tensor\]` did " 847 "not match the types of the given list " 848 "elements", 849 ) 850 851 # TODO(@ansley): Support mixed list comprehensions 852 self._assert_raises( 853 template, 854 "List[torch.Tensor] | int", 855 lhs["list_comprehension_of_mixed"], 856 "Arguments for call are not valid", 857 ) 858 859 def test_union_with_dict_assignment(self): 860 template = dedent( 861 """ 862 def fn(): 863 x: {ann} = {lhs} 864 if torch.jit.isinstance(x, Dict[str, torch.Tensor]): 865 x["foo"] = torch.tensor(3) 866 return x 867 """ 868 ) 869 870 lhs = { 871 "dict_literal_empty": "{}", 872 "dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}', 873 "dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}', 874 "dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}', 875 "dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \ 876 zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}', 877 "dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \ 878 zip(["foo", "bar"], [1, 2]}', 879 "dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \ 880 zip(["foo", "bar"], [torch.arange(3), 2])}', 881 "dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))", 882 "dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])', 883 "dict_keyword_with_empty_iterable": "dict([])", 884 "dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])', 885 "dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})', 886 "dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))', 887 } 888 889 """ 890 Dict[str, torch.Tensor] | Dict[str, int] 891 """ 892 self._assert_raises( 893 template, 894 "List[str] | List[torch.Tensor]", 895 lhs["dict_literal_empty"], 896 "Expected an Union type annotation with an " "inner Dict type", 897 ) 898 899 self._assert_passes( 900 template, 901 "Dict[str, torch.Tensor] | Dict[str, int]", 902 lhs["dict_literal_of_str_tensor"], 903 ) 904 905 self._assert_passes( 906 template, 907 "Dict[str, torch.Tensor] | Dict[str, int]", 908 lhs["dict_literal_of_str_int"], 909 ) 910 911 self._assert_raises( 912 template, 913 "Dict[str, torch.Tensor] | Dict[str, int]", 914 lhs["dict_literal_of_mixed"], 915 "none of those dict types can hold the " 916 "types of the given keys and values", 917 ) 918 919 # TODO: String frontend does not support tuple unpacking 920 # https://github.com/pytorch/pytorch/issues/64096 921 # self._assert_passes(template, "Dict[str, torch.Tensor] | Dict[str, int]", 922 # lhs["dict_comprehension_of_str_tensor"]) 923 924 # self._assert_passes(template, "Dict[str, torch.Tensor] | Dict[str, int]", 925 # lhs["dict_comprehension_of_str_int"]) 926 927 # self._assert_raises(template, "Dict[str, torch.Tensor] | Dict[str, int]", 928 # lhs["dict_comprehension_of_mixed"], 929 # "foobar") 930 931 # self._assert_passes(template, 932 # "Dict[str, torch.Tensor] | Dict[str, int]", 933 # lhs["dict_keyword_with_internal_aggregate_function"]) 934 935 # TODO(@ansley): Follow-up project needed for full type 936 # inference with dict keyword (supported for dict comprehension 937 # and dict literal already; should not be a blocker for anyone) 938 self._assert_raises( 939 template, 940 "Dict[str, torch.Tensor] | Dict[str, int]", 941 lhs["dict_keyword"], 942 "full type inference is not yet supported", 943 ) 944 945 self._assert_raises( 946 template, 947 "Dict[str, torch.Tensor] | Dict[str, int]", 948 lhs["dict_keyword_with_iterable"], 949 "full type inference is not yet supported", 950 ) 951 952 self._assert_raises( 953 template, 954 "Dict[str, torch.Tensor] | Dict[str, int]", 955 lhs["dict_keyword_with_empty_iterable"], 956 "full type inference is not yet supported", 957 ) 958 959 self._assert_raises( 960 template, 961 "Dict[str, torch.Tensor] | Dict[str, int]", 962 lhs["dict_keyword_with_mapping"], 963 "full type inference is not yet supported", 964 ) 965 966 self._assert_raises( 967 template, 968 "Dict[str, torch.Tensor] | Dict[str, int]", 969 lhs["dict_keyword_with_mapping_and_kwargs"], 970 "full type inference is not yet supported", 971 ) 972 973 """ 974 int | torch.Tensor 975 """ 976 self._assert_raises( 977 template, 978 "int | torch.Tensor", 979 lhs["dict_literal_empty"], 980 "Expected an Union type annotation with " "an inner Dict type", 981 ) 982 983 self._assert_raises( 984 template, 985 "int | torch.Tensor", 986 lhs["dict_literal_of_str_tensor"], 987 "Expected an Union type annotation with " "an inner Dict type", 988 ) 989 990 # See above--string frontend does not support tuple unpacking 991 # self._assert_raises(template, "int | torch.Tensor", 992 # lhs["dict_comprehension_of_tensor"], 993 # "foobar") 994 995 """ 996 Dict[str, torch.Tensor] | int 997 """ 998 self._assert_passes( 999 template, "Dict[str, torch.Tensor] | int", lhs["dict_literal_empty"] 1000 ) 1001 1002 self._assert_passes( 1003 template, "Dict[str, torch.Tensor] | int", lhs["dict_literal_of_str_tensor"] 1004 ) 1005 1006 self._assert_raises( 1007 template, 1008 "Dict[str, torch.Tensor] | int", 1009 lhs["dict_literal_of_str_int"], 1010 "Type annotation was inferred to be " 1011 r"`Dict\[str, Tensor\]`, but the type of " 1012 "values given by the dict literal is", 1013 ) 1014 1015 self._assert_raises( 1016 template, 1017 "Dict[str, torch.Tensor] | int", 1018 lhs["dict_literal_of_mixed"], 1019 "Type annotation was inferred to be " 1020 r"`Dict\[str, Tensor\]`, but the type of " 1021 "values given by the dict literal is", 1022 ) 1023 1024 self._assert_passes( 1025 template, "Dict[str, torch.Tensor] | int", lhs["dict_keyword"] 1026 ) 1027 1028 self._assert_passes( 1029 template, "Dict[str, torch.Tensor] | int", lhs["dict_keyword_with_iterable"] 1030 ) 1031 1032 self._assert_passes( 1033 template, 1034 "Dict[str, torch.Tensor] | int", 1035 lhs["dict_keyword_with_empty_iterable"], 1036 ) 1037 1038 self._assert_passes( 1039 template, "Dict[str, torch.Tensor] | int", lhs["dict_keyword_with_mapping"] 1040 ) 1041 1042 self._assert_passes( 1043 template, 1044 "Dict[str, torch.Tensor] | int", 1045 lhs["dict_keyword_with_mapping_and_kwargs"], 1046 ) 1047 1048 # See above--string frontend does not support tuple unpacking 1049 # self._assert_passes(template, 1050 # "Dict[str, torch.Tensor] | int", 1051 # lhs["dict_keyword_with_internal_aggregate_function"]) 1052 # 1053 # self._assert_passes(template, 1054 # "Dict[str, torch.Tensor] | int", 1055 # lhs["dict_comprehension_of_str_tensor"]) 1056 1057 # self._assert_raises(template, 1058 # "Dict[str, torch.Tensor] | int", 1059 # lhs["dict_comprehension_of_str_int"], 1060 # "foobar") 1061 1062 # self._assert_raises(template, 1063 # "Dict[str, torch.Tensor] | int", 1064 # lhs["dict_comprehension_of_mixed"], 1065 # "foobar") 1066