xref: /aosp_15_r20/external/pytorch/test/jit/test_union_pep604.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import io
4import os
5import sys
6import unittest
7from enum import Enum
8from textwrap import dedent
9from typing import Dict, List, Optional, Tuple, Union
10
11import torch
12from torch.testing import FileCheck
13
14
# Make the helper files in test/ importable: the directory two levels above
# this file's resolved location is appended to `sys.path`.
_this_file = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(os.path.dirname(_this_file))
sys.path.append(pytorch_test_dir)
18from torch.testing._internal.jit_utils import JitTestCase, make_global
19
20
# This module is meant to be collected through test/test_jit.py; refuse to
# run when executed as a script.
if __name__ == "__main__":
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
26    )
27
28
29@unittest.skipIf(sys.version_info < (3, 10), "Requires Python 3.10")
30class TestUnion(JitTestCase):
31    """
32    This class tests the functionality of `Union`.
33
34    Note: It's important to be able to refine the type of a `Union` to
35    one of its internal types. Currently, there are differences in the
36    way Python expects `isinstance` checks and the way TorchScript
37    expects `isinstance` checks. This means that we can't use
38    `checkScript` in our test cases because either the eager mode or the
39    script mode wouldn't run! So, some test cases have separate but
40    equivalent functions to emulate `checkScript`.
41    """
42
43    def test_check_union_annotation(self):
44        def test_func(a: int | float, b: Optional[int]):
45            return 0
46
47        scripted_func = torch.jit.script(test_func)
48        graph_rep = str(scripted_func.graph)
49        code_rep = str(scripted_func.code)
50        # TS graph IR for Union should be annotated as Union()
51        FileCheck().check("Union(").check("int?").run(graph_rep)
52        # Serialized code for Union should be annotated as Union[]
53        FileCheck().check("Union[").check("Optional[int]").run(code_rep)
54        self.checkScript(test_func, (5, 6))
55        # this shouldn't error out
56        torch._C.parse_ir(str(scripted_func.graph))
57
58    def test_union_with_scalar_values(self):
59        def fn(x: int | float) -> str:
60            return "foo"
61
62        self.checkScript(fn, (1,))
63        self.checkScript(fn, (1.0,))
64
65        scripted = torch.jit.script(fn)
66
67        with self.assertRaisesRegex(
68            RuntimeError,
69            "Expected a member of"
70            r" Union\[float, int\] but "
71            "instead found type str",
72        ):
73            scripted("1")
74
75    def test_union_with_collections(self):
76        def fn(x: Dict[str, int] | List[int]) -> str:
77            return "foo"
78
79        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
80        self.checkScript(fn, ([1, 2, 3],))
81
82        scripted = torch.jit.script(fn)
83
84        with self.assertRaisesRegex(
85            RuntimeError,
86            "Expected a member of"
87            r" Union\[List\[int\], Dict\[str, "
88            r"int\]\] but instead found type "
89            r"Dict\[str, str\]",
90        ):
91            scripted({"foo": "bar", "baz": "qux"})
92
93        with self.assertRaisesRegex(
94            RuntimeError,
95            "Expected a member of"
96            r" Union\[List\[int\], Dict\[str, "
97            r"int\]\] but instead found type "
98            r"List\[str\]",
99        ):
100            scripted(["foo", "bar", "baz"])
101
102        with self.assertRaisesRegex(
103            RuntimeError,
104            "Expected a member of"
105            r" Union\[List\[int\], Dict\[str, "
106            r"int\]\] but instead found type "
107            "str",
108        ):
109            scripted("1")
110
111    def test_union_with_enum(self):
112        class Color(Enum):
113            RED = 1
114            GREEN = 2
115
116        make_global(Color)
117
118        def fn(x: str | Color) -> str:
119            return "foo"
120
121        self.checkScript(fn, (Color.RED,))
122        self.checkScript(fn, ("red",))
123
124        scripted = torch.jit.script(fn)
125
126        with self.assertRaisesRegex(
127            RuntimeError,
128            "Expected a member of"
129            r" Union\[__torch__.jit.test_union_pep604."
130            r"Color, str\] but instead found "
131            "type int",
132        ):
133            scripted(1)
134
135    def test_union_in_class_constructor(self):
136        @torch.jit.script  # noqa: B903
137        class A:  # noqa: B903
138            def __init__(self, x: int | str) -> None:
139                self.x = x
140
141        def fn(x: str | int) -> A:
142            return A(x)
143
144        self.assertEqual(fn("foo").x, "foo")
145        self.assertEqual(fn(1).x, 1)
146
147        scripted = torch.jit.script(fn)
148
149        with self.assertRaisesRegex(
150            RuntimeError,
151            "Expected a member of"
152            r" Union\[int, str\] but instead "
153            r"found type List\[str\]",
154        ):
155            scripted(["foo", "bar", "baz"])
156
    def test_union_return_type(self):
        """A function declared to return ``int | str`` may return a ``str``."""

        def fn(x: int) -> int | str:
            return "foo"

        self.checkScript(fn, (1,))
162
    def test_union_as_annotation(self):
        """A local variable can be annotated with a PEP 604 union and returned
        through a union return type."""

        def fn() -> int | str:
            x: int | str = "foo"
            return x

        self.checkScript(fn, ())
169
    def test_union_as_annotation_in_typed_container(self):
        """A union can be a container's element type: values of both members
        are appended to a ``List[int | str]``."""

        def fn() -> None:
            l: List[int | str] = []
            u1: int | str = "foo"
            u2: int | str = 1
            l.append(u1)
            l.append(u2)

        self.checkScript(fn, ())
179
    def test_union_as_annotation_py2(self):
        """The return type of ``fn`` is declared with a Python 2 style type
        comment rather than an annotation."""

        # NOTE: the type comment on the first line of `fn` is part of what is
        # being tested; keep it exactly as written.
        def fn():
            # type: () -> int | str
            x: int | str = "foo"
            return x

        self.checkScript(fn, ())
187
    def test_union_as_internal_tuple_type(self):
        """Unions can be element types inside a ``Tuple`` annotation."""

        def fn():
            t: Tuple[int | str, int | str] = (1, "foo")
            return t

        self.checkScript(fn, ())
194
    def test_union_variable_can_be_reassigned(self):
        """A union-annotated local can be rebound to each member type. After
        ``x = i``, ``x`` is passed where an ``int`` is required, and later it
        is rebound to a ``str``."""

        @torch.jit.script
        def aux1(i: int):
            return int(i**2)

        @torch.jit.script
        def aux2(s: str):
            return s + s

        # `fn` calls the scripted helpers defined above in this method.
        def fn() -> int | str:
            x: int | str = "foo"
            i: int = 1
            x = i
            y: int = aux1(x)
            z: str = aux2(str(y))
            x = z
            return x

        self.checkScript(fn, ())
214
    def test_union_does_not_replace_existing_annotated_type(self):
        """Appending a ``str`` to a ``List[int]`` is a compile error; the
        declared element type is not widened into a union."""

        def fn():
            x: List[int] = [1, 2, 3]
            x.append("foo")
            return x

        # Scripting is expected to raise, so `scripted()` is presumably only
        # reached if compilation unexpectedly succeeds.
        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()
224
    def test_union_does_not_replace_existing_annotated_type_union(self):
        """Appending a ``float`` to a ``List[int | str]`` is a compile error;
        the declared union element type is not widened to include ``float``."""

        def fn():
            x: List[int | str] = [1, "foo", 3]
            x.append(2.0)
            return x

        # Scripting is expected to raise, so `scripted()` is presumably only
        # reached if compilation unexpectedly succeeds.
        with self.assertRaisesRegex(RuntimeError, "Could not match type float"):
            scripted = torch.jit.script(fn)
            scripted()
234
    def test_union_does_not_replace_existing_annotated_type_empty_container(self):
        """Same as the ``List[int]`` case, but starting from an empty list
        literal: appending a ``str`` is still a compile error."""

        def fn():
            x: List[int] = []
            x.append("foo")
            return x

        # Scripting is expected to raise, so `scripted()` is presumably only
        # reached if compilation unexpectedly succeeds.
        with self.assertRaisesRegex(RuntimeError, "Could not match type str"):
            scripted = torch.jit.script(fn)
            scripted()
244
    def test_unions_of_unions_are_flattened(self):
        """A nested union ``(int | str) | float`` appears in the graph as one
        flat ``Union`` with three members."""

        @torch.jit.script
        def fn(x: (int | str) | float) -> str:
            return "foo"

        s = fn.graph

        # The expected member order differs from the declaration order.
        FileCheck().check("x : Union(float, int, str)").run(s)
253
    def test_unions_of_a_single_argument_vanish(self):
        """A one-member ``Union[int]`` annotation shows up in the graph as a
        plain ``int``."""

        @torch.jit.script
        def fn(x: Union[int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : int").run(s)
262
    def test_union_redundant_arguments_are_skipped(self):
        """A repeated member (``int | str | int``) appears only once in the
        graph's ``Union``."""

        @torch.jit.script
        def fn(x: int | str | int) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(int, str)").run(s)
271
    def test_union_redundant_arguments_are_skipped_optional(self):
        """Mixing ``Optional`` members collapses to one union with a single
        ``NoneType`` entry."""

        @torch.jit.script
        def fn(x: int | Optional[float] | Optional[int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float, int, NoneType)").run(s)
280
    def test_union_redundant_arguments_are_skipped_subtyping(self):
        """Per the expected graph, ``Tuple[int, int]`` is dropped when the
        union already contains ``Tuple[Optional[int], int]``."""

        @torch.jit.script
        def fn(x: str | Tuple[Optional[int], int] | Tuple[int, int]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union((int?, int), str)").run(s)
289
    def test_union_redundant_arguments_are_skipped_container(self):
        """Duplicate container members (``List[str]`` twice) collapse into one
        entry."""

        @torch.jit.script
        def fn(x: List[str] | List[float] | List[str]) -> str:
            return "foo"

        s = fn.graph

        FileCheck().check("x : Union(float[], str[])").run(s)
298
    def test_union_argument_order_is_ignored(self):
        """``int | str`` and ``str | int`` produce the same graph annotation."""

        @torch.jit.script
        def fn1(x: int | str) -> str:
            return "foo"

        @torch.jit.script
        def fn2(x: str | int) -> str:
            return "foo"

        for s in (fn1.graph, fn2.graph):
            FileCheck().check("x : Union(int, str)").run(s)
310
    def test_union_argument_order_is_ignored_container(self):
        """Container members in either order produce the same graph
        annotation."""

        @torch.jit.script
        def fn1(x: List[str] | List[int]) -> str:
            return "foo"

        @torch.jit.script
        def fn2(x: List[int] | List[str]) -> str:
            return "foo"

        for s in (fn1.graph, fn2.graph):
            FileCheck().check("x : Union(int[], str[])").run(s)
322
323    def test_union_T_None_is_equivalent_to_optional_T(self):
324        @torch.jit.script
325        def inner(x: int | None) -> int:
326            if x is not None:
327                return x
328            else:
329                return 5
330
331        @torch.jit.script
332        def fn1() -> int:
333            a: Optional[int] = 5
334            b: Optional[int] = None
335            a_ = inner(a)
336            b_ = inner(b)
337            return a_ + b_
338
339        self.assertEqual(fn1(), 10)
340
341        @torch.jit.script
342        def inner2(x: Optional[int]) -> int:
343            if x is not None:
344                return x
345            else:
346                return 5
347
348        @torch.jit.script
349        def fn2() -> int:
350            a: int | None = 5
351            b: int | None = None
352            a_ = inner(a)
353            b_ = inner(b)
354            return a_ + b_
355
356        self.assertEqual(fn2(), 10)
357
    @unittest.expectedFailure
    def test_union_optional_of_union_return(self):
        """Returns an ``Optional[int | str]`` local from a function declared to
        return ``None | str | int``. Currently marked as an expected failure."""

        @torch.jit.script
        def fn() -> None | str | int:
            y: Optional[int | str] = "foo"
            return y
364
365    @unittest.expectedFailure
366    def test_union_optional_of_union_is_flattened(self):
367        @torch.jit.script
368        def fn(flag: int) -> str | int | None:
369            y: int | str | None = "foo"
370            if flag == 0:
371                x: Optional[int | str] = y
372            elif flag == 1:
373                x: Optional[int | str] = 1
374            else:
375                x: Optional[int | str] = None
376            return x
377
378        # Can't use `checkScript` because it will flag the fact that
379        # the original code has `Optional[Union[int, str]]` but the
380        # saved/loaded code has `Union[int, NoneType, str]` (even
381        # though this is exactly what we want)
382        self.assertEqual(fn(0), "foo")
383        self.assertEqual(fn(1), 1)
384        self.assertEqual(fn(2), None)
385
386        buffer = io.BytesIO()
387        torch.jit.save(fn, buffer)
388        buffer = io.BytesIO(buffer.getvalue())
389        l = torch.jit.load(buffer)
390
391        s = l.code
392
393        FileCheck().check("Union[int, NoneType, str]").check(
394            "Union[int, NoneType, str]"
395        ).run(s)
396
    def test_union_subclasses_larger_union(self):
        """An ``int | str`` value can be returned where the wider
        ``int | str | torch.Tensor`` is declared."""

        def fn() -> int | str | torch.Tensor:
            x: int | str = "foo"
            return x

        self.checkScript(fn, ())
403
    # TODO: We would like to eventually support this. The issue is being
    # tracked at https://github.com/pytorch/pytorch/issues/58167
    def test_union_as_dict_key(self):
        """A union as a ``Dict`` key type is currently rejected when the
        function is scripted."""

        def fn():
            x: Dict[int | str, str] = {}
            x["foo"] = "bar"
            x[1] = 2
            return x[1]

        with self.assertRaisesRegex(
            RuntimeError,
            "only int, float, "
            "complex, Tensor, device and string keys "
            "are supported",
        ):
            torch.jit.script(fn)
420
    def test_union_as_dict_value(self):
        """A union is accepted as a ``Dict`` value type; values of both
        members are stored."""

        def fn():
            x: Dict[str, int | str] = {}
            x["foo"] = "bar"
            x["baz"] = 2
            return x["baz"]

        self.checkScript(fn, ())
429
430    def test_union_module_with_union_instance_variable(self):
431        class M(torch.nn.Module):
432            x: int | str
433
434            def __init__(self, x: int | str):
435                super().__init__()
436                self.x: int | str = x
437
438            def forward(self, y: int | str):
439                self.x = y
440                return self.x
441
442        self.checkModule(
443            M(
444                2,
445            ),
446            (1,),
447        )
448        self.checkModule(M("bar"), ("foo",))
449
450    def test_union_module_with_union_class_variable(self):
451        class M(torch.nn.Module):
452            x: int | str = "foo"
453
454            def __init__(self, y: int):
455                super().__init__()
456                x = y
457
458            def forward(self, z: str):
459                x = z
460                return x
461
462        self.checkModule(M(1), ("foo",))
463
    def test_union_type_refinement(self):
        """``isinstance`` narrows an ``int | str`` argument to ``str`` inside
        the true branch."""

        def fn(x: int | str) -> str:
            if isinstance(x, str):
                z = x + "bar"
                return x
            else:
                return "baz"

        # NOTE(review): `z` in `fn` is unused; presumably the concatenation is
        # there only so that compilation requires `x` to be refined to `str`.
        self.checkScript(fn, ("foo",))
        self.checkScript(fn, (1,))
474
    def test_union_type_refinement_union_rhs(self):
        """``torch.jit.isinstance`` accepts a PEP 604 union as its type
        argument."""

        def fn(x: int) -> str:
            if torch.jit.isinstance(x, int | str):
                return "bar"
            else:
                return "baz"

        self.checkScript(fn, (1,))
483
    def test_union_type_refinement_tuple_rhs(self):
        """``isinstance`` with a tuple of types refines to the matching members
        in the true branch; the ``else`` branch uses ``x`` as the remaining
        ``List[str]`` member."""

        def fn(x: int | float | List[str]) -> str:
            if isinstance(x, (int, float)):
                if isinstance(x, int):
                    return str(x)
                else:
                    return "foo"
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (1.0,))
        self.checkScript(fn, (["a", "b", "c"],))
500
    def test_union_type_refinement_tuple_rhs_noncontained_type(self):
        """The type tuple passed to ``isinstance`` may include a type
        (``float``) that is not a member of the union ``int | List[str]``."""

        def fn(x: int | List[str]) -> str:
            if isinstance(x, (int, float)):
                y = x + x
                return str(y)
            else:
                if len(x):
                    return x[0]
                else:
                    return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (["a", "b", "c"],))
514
    def test_union_type_refinement_tuple_rhs_union(self):
        """``torch.jit.isinstance`` accepts a type tuple that itself contains a
        union (``(int | str, float)``)."""

        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (int | str, float)):
                y = x + x
                return str(y)
            else:
                return "foo"

        # TODO: There's currently an unrelated bug in
        # `torch.jit.isinstance` that makes it fail for tuple literals.
        # Posted here: https://github.com/pytorch/pytorch/issues/60095
        # Change `assertEqual` to `checkScript` when the bug is fixed
        self.assertEqual(fn(1), "2")
529
    def test_union_type_refinement_statically_false(self):
        """An ``isinstance`` check that cannot succeed for an ``int`` argument
        is resolved at compile time, so the graph has no branch blocks."""

        # NOTE(review): the true branch (`x + "foo"` on an int) presumably
        # compiles only because it is eliminated as statically unreachable.
        @torch.jit.script
        def fn(x: int) -> str:
            if torch.jit.isinstance(x, (str | float, List[str], str)):
                z = x + "foo"
                return z
            else:
                return "bar"

        s = fn.graph

        # Check that we don't have any branching statements
        FileCheck().check_not("block0()").check_not("block1()").run(s)
543
    def test_union_type_refinement_statically_true(self):
        """An ``isinstance`` check that lists every member of the argument's
        union always succeeds, so the graph has no branch blocks."""

        @torch.jit.script
        def fn(x: List[int] | int) -> List[int] | int:
            if not torch.jit.isinstance(x, (int, List[int])):
                return x
            else:
                l = [1, 2, 3]
                y: List[int] | int = l
                return y

        s = fn.graph

        # Check that we don't have any branching statements
        FileCheck().check_not("block0()").check_not("block1()").run(s)
558
    def test_union_type_refinement_partial_static_refinement_tuple_rhs(self):
        """When only ``int`` from the type tuple is a union member, the true
        branch sees ``x`` as ``int``."""

        def fn(x: List[int] | int) -> int:
            if torch.jit.isinstance(x, (int, float, str)):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))
570
    def test_union_type_refinement_partial_static_refinement_union_rhs(self):
        """Same as the tuple variant, but the checked types are given as a
        PEP 604 union (``int | float | str``)."""

        def fn(x: List[int] | int) -> int:
            if torch.jit.isinstance(x, int | float | str):
                # We should know that `x` is an `int` here
                z = x + 1
                return z
            else:
                return 100

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, (1,))
582
    def test_union_type_refinement_internal_declaration(self):
        """A union-annotated local initialized to ``None`` can be refined with
        ``isinstance`` (here the check is always false, so "bar" is returned)."""

        def fn(flag: bool) -> str:
            x: int | str | None = None
            if flag:
                y = "foo"
            else:
                y = 1
            if isinstance(x, str):
                return x
            else:
                return "bar"

        # NOTE(review): `y` gets different types in the two branches but is
        # never read afterwards; presumably this checks that such an unused
        # variable does not cause a compile error.
        self.checkScript(fn, (True,))
        self.checkScript(fn, (False,))
597
    def test_union_branching_with_union_return_and_homogenous_types(self):
        """Both branches return ``str``, which satisfies an ``int | str``
        return annotation."""

        def fn(x: int) -> int | str:
            if x % 2:
                return "foo"
            else:
                return "bar"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))
607
    def test_union_branching_does_not_autoinfer_undeclared_union(self):
        """A variable given different types in the two branches (and used
        afterwards) is a compile error; no union is inferred for it."""

        def fn(x: int) -> str:
            if x % 2:
                y = "foo"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "bar"

        # The expected message names the variable `y` from `fn`.
        with self.assertRaisesRegex(
            RuntimeError,
            "y is set to type str"
            " in the true branch and type int "
            "in the false branch",
        ):
            torch.jit.script(fn)
626
    def test_union_branching_does_not_widen_existing_inferred_type(self):
        """A variable first inferred as ``str`` cannot later be assigned an
        ``int`` in one branch; its type is not widened to a union."""

        def fn(x: int) -> str:
            y = "foo"
            if x % 2:
                y = "bar"
            else:
                y = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        with self.assertRaisesRegex(
            RuntimeError,
            "previously had type "
            "str but is now being assigned to a"
            " value of type int",
        ):
            torch.jit.script(fn)
646
    def test_union_schema_matching_on_internal_type(self):
        """After ``torch.jit.isinstance`` refinement, operations specific to
        the refined member (indexing a list, ``dict.values()``) compile."""

        def fn(x: List[int] | Dict[str, int]) -> int:
            if torch.jit.isinstance(x, List[int]):
                return x[0]
            else:
                return list(x.values())[0]

        self.checkScript(fn, ([1, 2, 3],))
        self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},))
656
    def test_union_subtractive_refinement(self):
        """Under ``not isinstance(x, int)``, ``x`` is used as the remaining
        member ``List[int]``."""

        def fn(x: List[int] | int) -> int:
            if not isinstance(x, int):
                x.append(1)
                return x[0]
            else:
                return x

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))
667
    def test_union_subtractive_refinement_with_container(self):
        """Under ``not torch.jit.isinstance(x, List[int])``, ``x`` is returned
        as the remaining member ``int``."""

        def fn(x: List[int] | int) -> int:
            if not torch.jit.isinstance(x, List[int]):
                return x
            else:
                x.append(1)
                return x[0]

        self.checkScript(fn, (1,))
        self.checkScript(fn, ([1, 2, 3],))
678
    def test_union_memory_aliasing(self):
        """A tensor appended through a refined alias (``z[0]``) should be
        visible through the original list ``x``; checkScript compares the
        returned ``x`` between eager and scripted runs."""

        def fn():
            x: List[torch.Tensor] = []
            z: List[Optional[List[torch.Tensor]]] = []
            z.append(x)
            x_alias = z[0]
            if torch.jit.isinstance(x_alias, List[torch.Tensor]):
                x_alias.append(torch.tensor(3))
            return x

        self.checkScript(fn, ())
690
    def test_union_serialization_preserves_type_annotations(self):
        """Union annotations on locals survive checkScript's save/load round
        trip."""

        # This function will fail after being torch.jit.save'd and
        # torch.jit.load'd if the type annotations aren't preserved
        # for Union during serialization. We need the `Union[str, int]`
        # annotation to make sure that `y` is typed as a Union instead
        # of as a str in one branch and an int in the other
        def fn(x: int) -> str:
            if x % 2:
                y: str | int = "bar"
            else:
                y: str | int = x
            if isinstance(y, str):
                return y
            else:
                return "baz"

        self.checkScript(fn, (1,))
        self.checkScript(fn, (8,))
709
710    def _assert_passes(self, template: str, ann: str, lhs: str):
711        code = template.format(ann=ann, lhs=lhs)
712        self.checkScript(code, (), name="fn")
713
714    def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
715        code = template.format(ann=ann, lhs=lhs)
716        with self.assertRaisesRegex(RuntimeError, msg):
717            cu = torch.jit.CompilationUnit(code, _frames_up=1)
718            string_frontend = getattr(cu, "fn")  # noqa: B009
719
720    def test_union_with_list_assignment(self):
721        template = dedent(
722            """
723            def fn():
724                x: {ann} = {lhs}
725                if torch.jit.isinstance(x, List[torch.Tensor]):
726                    x.append(torch.tensor(3))
727                return x
728        """
729        )
730
731        lhs = {
732            "list_literal_empty": "[]",
733            "list_literal_of_tensor": "[torch.arange(3), torch.arange(5)]",
734            "list_literal_of_str": '["foo", "bar", "baz"]',
735            "list_literal_of_mixed": "[torch.arange(5), 1]",
736            "list_comprehension_of_tensor": "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",
737            "list_comprehension_of_str": '[x + "!" for x in ["foo", "bar", "baz"]]',
738            "list_comprehension_of_mixed": "[torch.add(1, x) for x in [torch.arange(5), 1]]",
739        }
740
741        """
742        List[str] | List[torch.Tensor]
743        """
744        self._assert_raises(
745            template,
746            "List[str] | List[torch.Tensor]",
747            lhs["list_literal_empty"],
748            "there are multiple possible List type "
749            "candidates in the Union annotation",
750        )
751
752        self._assert_passes(
753            template, "List[str] | List[torch.Tensor]", lhs["list_literal_of_tensor"]
754        )
755
756        self._assert_passes(
757            template, "List[str] | List[torch.Tensor]", lhs["list_literal_of_str"]
758        )
759
760        self._assert_raises(
761            template,
762            "List[str] | List[torch.Tensor]",
763            lhs["list_literal_of_mixed"],
764            "none of those types match the types of the" " given list elements",
765        )
766
767        self._assert_passes(
768            template,
769            "List[str] | List[torch.Tensor]",
770            lhs["list_comprehension_of_tensor"],
771        )
772
773        self._assert_passes(
774            template, "List[str] | List[torch.Tensor]", lhs["list_comprehension_of_str"]
775        )
776
777        # TODO: Support mixed list comprehensions
778        self._assert_raises(
779            template,
780            "List[str] | List[torch.Tensor]",
781            lhs["list_comprehension_of_mixed"],
782            "Arguments for call are not valid",
783        )
784
785        """
786        int | torch.Tensor
787        """
788        self._assert_raises(
789            template,
790            "int | torch.Tensor",
791            lhs["list_literal_empty"],
792            "Expected an Union type annotation with an " "inner List type",
793        )
794
795        self._assert_raises(
796            template,
797            "int | torch.Tensor",
798            lhs["list_literal_of_tensor"],
799            "Expected an Union type annotation with an " "inner List type",
800        )
801
802        self._assert_raises(
803            template,
804            "int | torch.Tensor",
805            lhs["list_comprehension_of_tensor"],
806            "Expected an Union type annotation with an " "inner List type",
807        )
808
809        """
810        List[torch.Tensor] | int
811        """
812        self._assert_passes(
813            template, "List[torch.Tensor] | int", lhs["list_literal_empty"]
814        )
815
816        self._assert_passes(
817            template, "List[torch.Tensor] | int", lhs["list_literal_of_tensor"]
818        )
819
820        self._assert_raises(
821            template,
822            "List[torch.Tensor] | int",
823            lhs["list_literal_of_str"],
824            r"List type annotation `List\[Tensor\]` did "
825            "not match the types of the given list "
826            "elements",
827        )
828
829        self._assert_raises(
830            template,
831            "List[torch.Tensor] | int",
832            lhs["list_literal_of_mixed"],
833            r"List type annotation `List\[Tensor\]` did "
834            "not match the types of the given list "
835            "elements",
836        )
837
838        self._assert_passes(
839            template, "List[torch.Tensor] | int", lhs["list_comprehension_of_tensor"]
840        )
841
842        self._assert_raises(
843            template,
844            "List[torch.Tensor] | int",
845            lhs["list_comprehension_of_str"],
846            r"List type annotation `List\[Tensor\]` did "
847            "not match the types of the given list "
848            "elements",
849        )
850
851        # TODO(@ansley): Support mixed list comprehensions
852        self._assert_raises(
853            template,
854            "List[torch.Tensor] | int",
855            lhs["list_comprehension_of_mixed"],
856            "Arguments for call are not valid",
857        )
858
859    def test_union_with_dict_assignment(self):
860        template = dedent(
861            """
862            def fn():
863                x: {ann} = {lhs}
864                if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
865                    x["foo"] = torch.tensor(3)
866                return x
867        """
868        )
869
870        lhs = {
871            "dict_literal_empty": "{}",
872            "dict_literal_of_str_tensor": '{"foo" : torch.arange(3), "bar" : torch.arange(5)}',
873            "dict_literal_of_str_int": '{"foo" : 1, "bar" : 2}',
874            "dict_literal_of_mixed": '{"foo" : torch.arange(3), "bar" : 2}',
875            "dict_comprehension_of_str_tensor": '{x : torch.add(y, 1) for x, y in \
876                    zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])}',
877            "dict_comprehension_of_str_int": '{x : torch.add(y, 1) for x, y in \
878                    zip(["foo", "bar"], [1, 2]}',
879            "dict_comprehension_of_mixed": '{x : torch.add(y, 1) for x, y in \
880                    zip(["foo", "bar"], [torch.arange(3), 2])}',
881            "dict_keyword": "dict(foo=torch.arange(3), baz=torch.arange(5))",
882            "dict_keyword_with_iterable": 'dict([("foo", torch.arange(3)), ("bar", torch.arange(5))])',
883            "dict_keyword_with_empty_iterable": "dict([])",
884            "dict_keyword_with_internal_aggregate_function": 'dict(zip(["foo", "bar"], [torch.arange(3), torch.arange(5)])',
885            "dict_keyword_with_mapping": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)})',
886            "dict_keyword_with_mapping_and_kwargs": 'dict({"foo" : torch.arange(3), "bar" : torch.arange(5)}, baz=torch.arange(7))',
887        }
888
889        """
890        Dict[str, torch.Tensor] | Dict[str, int]
891        """
892        self._assert_raises(
893            template,
894            "List[str] | List[torch.Tensor]",
895            lhs["dict_literal_empty"],
896            "Expected an Union type annotation with an " "inner Dict type",
897        )
898
899        self._assert_passes(
900            template,
901            "Dict[str, torch.Tensor] | Dict[str, int]",
902            lhs["dict_literal_of_str_tensor"],
903        )
904
905        self._assert_passes(
906            template,
907            "Dict[str, torch.Tensor] | Dict[str, int]",
908            lhs["dict_literal_of_str_int"],
909        )
910
911        self._assert_raises(
912            template,
913            "Dict[str, torch.Tensor] | Dict[str, int]",
914            lhs["dict_literal_of_mixed"],
915            "none of those dict types can hold the "
916            "types of the given keys and values",
917        )
918
919        # TODO: String frontend does not support tuple unpacking
920        # https://github.com/pytorch/pytorch/issues/64096
921        # self._assert_passes(template, "Dict[str, torch.Tensor] | Dict[str, int]",
922        #              lhs["dict_comprehension_of_str_tensor"])
923
924        # self._assert_passes(template, "Dict[str, torch.Tensor] | Dict[str, int]",
925        #              lhs["dict_comprehension_of_str_int"])
926
927        # self._assert_raises(template, "Dict[str, torch.Tensor] | Dict[str, int]",
928        #              lhs["dict_comprehension_of_mixed"],
929        #              "foobar")
930
931        # self._assert_passes(template,
932        #                    "Dict[str, torch.Tensor] | Dict[str, int]",
933        #                    lhs["dict_keyword_with_internal_aggregate_function"])
934
935        # TODO(@ansley): Follow-up project needed for full type
936        # inference with dict keyword (supported for dict comprehension
937        # and dict literal already; should not be a blocker for anyone)
938        self._assert_raises(
939            template,
940            "Dict[str, torch.Tensor] | Dict[str, int]",
941            lhs["dict_keyword"],
942            "full type inference is not yet supported",
943        )
944
945        self._assert_raises(
946            template,
947            "Dict[str, torch.Tensor] | Dict[str, int]",
948            lhs["dict_keyword_with_iterable"],
949            "full type inference is not yet supported",
950        )
951
952        self._assert_raises(
953            template,
954            "Dict[str, torch.Tensor] | Dict[str, int]",
955            lhs["dict_keyword_with_empty_iterable"],
956            "full type inference is not yet supported",
957        )
958
959        self._assert_raises(
960            template,
961            "Dict[str, torch.Tensor] | Dict[str, int]",
962            lhs["dict_keyword_with_mapping"],
963            "full type inference is not yet supported",
964        )
965
966        self._assert_raises(
967            template,
968            "Dict[str, torch.Tensor] | Dict[str, int]",
969            lhs["dict_keyword_with_mapping_and_kwargs"],
970            "full type inference is not yet supported",
971        )
972
973        """
974        int | torch.Tensor
975        """
976        self._assert_raises(
977            template,
978            "int | torch.Tensor",
979            lhs["dict_literal_empty"],
980            "Expected an Union type annotation with " "an inner Dict type",
981        )
982
983        self._assert_raises(
984            template,
985            "int | torch.Tensor",
986            lhs["dict_literal_of_str_tensor"],
987            "Expected an Union type annotation with " "an inner Dict type",
988        )
989
990        # See above--string frontend does not support tuple unpacking
991        # self._assert_raises(template, "int | torch.Tensor",
992        #              lhs["dict_comprehension_of_tensor"],
993        #              "foobar")
994
995        """
996        Dict[str, torch.Tensor] | int
997        """
998        self._assert_passes(
999            template, "Dict[str, torch.Tensor] | int", lhs["dict_literal_empty"]
1000        )
1001
1002        self._assert_passes(
1003            template, "Dict[str, torch.Tensor] | int", lhs["dict_literal_of_str_tensor"]
1004        )
1005
1006        self._assert_raises(
1007            template,
1008            "Dict[str, torch.Tensor] | int",
1009            lhs["dict_literal_of_str_int"],
1010            "Type annotation was inferred to be "
1011            r"`Dict\[str, Tensor\]`, but the type of "
1012            "values given by the dict literal is",
1013        )
1014
1015        self._assert_raises(
1016            template,
1017            "Dict[str, torch.Tensor] | int",
1018            lhs["dict_literal_of_mixed"],
1019            "Type annotation was inferred to be "
1020            r"`Dict\[str, Tensor\]`, but the type of "
1021            "values given by the dict literal is",
1022        )
1023
1024        self._assert_passes(
1025            template, "Dict[str, torch.Tensor] | int", lhs["dict_keyword"]
1026        )
1027
1028        self._assert_passes(
1029            template, "Dict[str, torch.Tensor] | int", lhs["dict_keyword_with_iterable"]
1030        )
1031
1032        self._assert_passes(
1033            template,
1034            "Dict[str, torch.Tensor] | int",
1035            lhs["dict_keyword_with_empty_iterable"],
1036        )
1037
1038        self._assert_passes(
1039            template, "Dict[str, torch.Tensor] | int", lhs["dict_keyword_with_mapping"]
1040        )
1041
1042        self._assert_passes(
1043            template,
1044            "Dict[str, torch.Tensor] | int",
1045            lhs["dict_keyword_with_mapping_and_kwargs"],
1046        )
1047
1048        # See above--string frontend does not support tuple unpacking
1049        # self._assert_passes(template,
1050        #                    "Dict[str, torch.Tensor] | int",
1051        #                    lhs["dict_keyword_with_internal_aggregate_function"])
1052        #
1053        # self._assert_passes(template,
1054        #                    "Dict[str, torch.Tensor] | int",
1055        #                    lhs["dict_comprehension_of_str_tensor"])
1056
1057        # self._assert_raises(template,
1058        #                    "Dict[str, torch.Tensor] | int",
1059        #                    lhs["dict_comprehension_of_str_int"],
1060        #                    "foobar")
1061
1062        # self._assert_raises(template,
1063        #                    "Dict[str, torch.Tensor] | int",
1064        #                    lhs["dict_comprehension_of_mixed"],
1065        #                    "foobar")
1066