xref: /aosp_15_r20/external/pytorch/test/jit/test_list_dict.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import inspect
4import os
5import sys
6import types
7import unittest
8from collections import defaultdict, OrderedDict
9from textwrap import dedent
10from typing import Any, Dict, List, NamedTuple, Optional, Tuple
11
12import torch
13import torch.nn as nn
14from torch import Tensor
15from torch.testing import FileCheck
16
17
18# Make the helper files in test/ importable
19pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
20sys.path.append(pytorch_test_dir)
21from torch.testing._internal.common_utils import skipIfTorchDynamo, TEST_CUDA
22from torch.testing._internal.jit_utils import JitTestCase, make_global
23
24
if __name__ == "__main__":
    # This module only contributes test cases to the JIT suite; refuse to run
    # standalone and point the user at the real entry point instead.
    _usage_hint = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_hint)
31
32
33class TestList(JitTestCase):
34    def test_list_bool_conversion(self):
35        def if_predicate(l: List[int]):
36            if l:
37                s = 0
38                for n in l:
39                    s += n
40
41                return s
42            else:
43                return -1
44
45        self.checkScript(if_predicate, ([1, 2, 3],))
46        self.checkScript(if_predicate, ([],))
47
48        def while_predicate(l: List[int]):
49            s = 0
50
51            while l:
52                s += l.pop()
53
54        self.checkScript(while_predicate, ([1, 2, 3],))
55        self.checkScript(while_predicate, ([],))
56
57        def ternary_predicate(l: List[int]):
58            return "non-empty" if l else "empty"
59
60        self.checkScript(ternary_predicate, ([1, 2, 3],))
61        self.checkScript(ternary_predicate, ([],))
62
63    def test_in_check(self):
64        def int_in(x: List[int]) -> bool:
65            return 2 in x
66
67        self.checkScript(int_in, ([1, 2, 3],))
68        self.checkScript(int_in, ([1, 3, 3],))
69
70        def float_in(x: List[float]) -> bool:
71            return 2.0 in x
72
73        self.checkScript(float_in, ([1.0, 2.0, 3.0],))
74        self.checkScript(float_in, ([1.0, 3.0, 3.0],))
75
76        def str_in(x: List[str]) -> bool:
77            return "hi" in x
78
79        self.checkScript(str_in, (["not", "here"],))
80        self.checkScript(str_in, (["hi", "bye"],))
81        self.checkScript(str_in, ([],))
82
83    def test_list_literal(self):
84        def reassign():
85            x = [1]
86            if 1 == 1:
87                x = [2, 3]
88            return
89
90        self.checkScript(reassign, (), optimize=False)
91
92        def reassign_arity_change():
93            x = [1]
94            if 1 == 1:
95                x = [1, 2, 3]
96            return
97
98        self.checkScript(reassign_arity_change, (), optimize=False)
99
100        def reassign_from_empty_literal():
101            x = []
102            if 1 == 1:
103                x = [1, 2, 3]
104            return
105
106        with self.assertRaisesRegexWithHighlight(
107            RuntimeError, r"previously had type List\[Tensor\]", "x"
108        ):
109            self.checkScript(reassign_from_empty_literal, (), optimize=False)
110
111        def reassign_from_empty_builtin():
112            x = torch.jit.annotate(List[int], [])
113            if 1 == 1:
114                x = [1, 2, 3]
115            y = torch.jit.annotate(List[float], [])
116            if 1 == 1:
117                y = [1.0, 2.0, 3.0]
118            z = []
119            if 1 == 1:
120                z = [torch.randn([1])]
121            return
122
123        self.checkScript(reassign_from_empty_builtin, (), optimize=False)
124
125        def reassign_bad_type():
126            x = [1]
127            if 1 == 1:
128                x = [1.0]
129            return
130
131        with self.assertRaisesRegexWithHighlight(
132            RuntimeError, "previously had type", "x"
133        ):
134            self.checkScript(reassign_bad_type, (), optimize=False)
135
136        def reassign_nested():
137            x = torch.jit.annotate(List[int], [])
138            if 1 == 1:
139                x = [1, 2, 3]
140                if 1 == 1:
141                    x = [1.0]
142            return
143
144        with self.assertRaisesRegexWithHighlight(
145            RuntimeError, "previously had type", "x"
146        ):
147            self.checkScript(reassign_nested, (), optimize=False)
148
149    def test_list_variance(self):
150        """
151        `List[T1]` is not a subtype of `List[T2]`, even if `T1` is a
152        subtype of `T2`. However, if we have a temporary list object
153        (that is, a list comprehension or a list literal) on the rhs of
154        an assignment statement, we want to ignore the inferred type of
155        the rhs if we can prove that: 1) both the lhs and the rhs are
156        lists, and 2) the inner type of the lhs list is a subtype of the
157        inner type of the rhs list.
158
159        # This should pass
160        x: List[Optional[int]] = [None, None, None]
161
162        # This should fail
163        y: List[None] = [None, None, None]
164        x: List[Optional[int]] = y
165        """
166
167        def test_listliteral_is_typed_from_annotation():
168            x: List[Optional[int]] = [None, None, None]
169            return x
170
171        self.checkScript(test_listliteral_is_typed_from_annotation, ())
172
173        def test_listcomprehension_is_typed_from_annotation():
174            x: List[Optional[int]] = [None for _ in range(3)]
175            return x
176
177        self.checkScript(test_listcomprehension_is_typed_from_annotation, ())
178
179        def test_lists_with_different_internal_types_are_invariant(self):
180            x: List[int] = [1, 2, 3]
181            y: List[Optional[int]] = x
182            return x
183
184        with self.assertRaisesRegex(
185            RuntimeError,
186            "Variable 'y' is "
187            "annotated with type "
188            r"List\[Optional\[int\]\] but is "
189            "being assigned to a value of type "
190            r"List\[int\]",
191        ):
192            torch.jit.script(test_lists_with_different_internal_types_are_invariant)
193
194        def test_lists_with_different_internal_types_are_invariant_recursive(self):
195            x: List[List[int]] = [[1, 2], [3]]
196            y: List[List[Optional[int]]] = x
197            return x
198
199        with self.assertRaisesRegex(
200            RuntimeError,
201            "Variable 'y' is "
202            "annotated with type "
203            r"List\[List\[Optional\[int\]\]\] "
204            "but is being assigned to a value "
205            r"of type List\[List\[int\]\]",
206        ):
207            torch.jit.script(
208                test_lists_with_different_internal_types_are_invariant_recursive
209            )
210
211    def test_del(self):
212        def inputs():
213            return [1, 2, 3, 4]
214
215        def fn(x: List[int]) -> List[int]:
216            del x[1]
217            return x
218
219        python_out = fn(inputs())
220        # checkScript reuses the same object, but here it's being mutated so do
221        # it manually
222        cu = torch.jit.CompilationUnit()
223        cu.define(dedent(inspect.getsource(fn)))
224        self.assertEqual(cu.fn(inputs()), python_out)
225        self.assertEqual(torch.jit.script(fn)(inputs()), python_out)
226
227        @torch.jit.script
228        def fn2(x: List[int]) -> List[int]:
229            del x[100]
230            return x
231
232        with self.assertRaisesRegexWithHighlight(
233            RuntimeError, "out of range", "x[100]"
234        ):
235            fn2([])
236
237        with self.assertRaisesRegexWithHighlight(
238            RuntimeError, "deletion at a single index", "x[1:3]"
239        ):
240
241            @torch.jit.script
242            def fn(x: List[int]) -> List[int]:
243                del x[1:3]
244                return x
245
246    def test_list_keyword(self):
247        def foo():
248            return (
249                list([1, 2, 3]),  # noqa: C410
250                list(("a", "b")),  # noqa: C410
251                list(range(5)),
252                list("abcdefg"),
253            )
254
255        self.checkScript(foo, ())
256
257        def foo2():
258            x: List[int] = list()  # noqa: C408
259            x.append(1)
260            return (x,)
261
262        self.checkScript(foo2, ())
263
264        def foo3():
265            return list(list("abc"))  # noqa: C414
266
267        self.checkScript(foo3, ())
268        FileCheck().check_count("aten::list", 2, exactly=True).run(
269            torch.jit.script(foo3).graph
270        )
271
272    def test_dict_keyword_with_kwargs(self):
273        def fn():
274            return dict(foo=1, bar=2, baz=3)
275
276        self.checkScript(fn, ())
277
278    def test_dict_keyword_with_kwargs_using_container_values(self):
279        def fn():
280            return dict(foo=[1, 2, 3], bar=[4, 5, 6], baz=[7, 8, 9])
281
282        self.checkScript(fn, ())
283
284    def test_dict_keyword_with_iterable(self):
285        def fn():
286            return dict([("foo", 1), ("bar", 2), ("baz", 3)])  # noqa: C406
287
288        self.checkScript(fn, ())
289
290    def test_dict_keyword_with_empty_iterable(self):
291        def fn():
292            return dict([])  # noqa: C406
293
294        self.checkScript(fn, ())
295
296    def test_dict_keyword_with_internal_aggregate_function(self):
297        def fn():
298            return dict(zip(["foo", "baz", "bar"], [1, 2, 3]))
299
300        self.checkScript(fn, ())
301
302    def test_dict_keyword_with_mapping(self):
303        def fn():
304            return {"foo": 1, "bar": 2, "baz": 3}
305
306        self.checkScript(fn, ())
307
308    def test_dict_keyword_with_mapping_and_kwargs(self):
309        def fn():
310            return dict({"foo": 1, "bar": 2}, baz=3)
311
312        self.checkScript(fn, ())
313
314    def test_dict_keyword_with_dict_comprehension(self):
315        def fn():
316            return {i: chr(i + 65) for i in range(4)}
317
318        self.checkScript(fn, ())
319
320    def test_dict_keyword_with_dict_comprehension_and_kwargs(self):
321        def fn():
322            return dict({chr(65 + i): i for i in range(4)}, foo=2)
323
324        self.checkScript(fn, ())
325
326    def test_dict_keyword_with_empty_dict_comprehension(self):
327        def fn():
328            return {}
329
330        self.checkScript(fn, ())
331
332    def test_dict_keyword_is_correctly_typed(self):
333        def fn():
334            x: Dict[str, int] = dict()  # noqa: C408
335            x["foo"] = 1
336            return x
337
338        self.checkScript(fn, ())
339
340    def test_dict_keyword_with_mismatched_annotations(self):
341        err_msg = (
342            r"Dict type annotation `Dict\[int, str\]` did not "
343            "match the type of an actual key type `str`"
344        )
345        with self.assertRaisesRegex(RuntimeError, err_msg):
346
347            @torch.jit.script
348            def fn():
349                x: Dict[int, str] = dict(  # noqa: C406
350                    [("foo", 1), ("bar", 2), ("baz", 3)]
351                )
352                return x
353
354    def test_dict_keyword_with_nested_call(self):
355        def fn():
356            return dict(dict(foo=1, bar=2, baz=3))
357
358        self.checkScript(fn, ())
359
360    def test_dict_keyword_with_previously_declared_variable(self):
361        def fn():
362            d = {"foo": 1, "bar": 2}
363            return dict(d)
364
365        self.checkScript(fn, ())
366
367    def test_dict_keyword_with_previously_declared_variable_and_kwargs(self):
368        def fn():
369            d = {"foo": 1, "bar": 2}
370            return dict(d, baz=3)
371
372        self.checkScript(fn, ())
373
374    def test_min_bool_list(self):
375        def jit_min_list(a: List[bool], b: List[bool]) -> List[bool]:
376            return min(a, b)
377
378        self.checkScript(jit_min_list, ([True, False], [False, True]))
379
380    def test_min_max_list(self):
381        def jit_min_list(a: List[int], b: List[int]) -> List[int]:
382            return min(a, b)
383
384        def jit_min_list_float(a: List[float], b: List[float]) -> List[float]:
385            return min(a, b)
386
387        def jit_min_list_bool(a: List[bool], b: List[bool]) -> List[bool]:
388            return min(a, b)
389
390        def run_tests(func, a, b):
391            for t in zip(a, b):
392                self.checkScript(func, t)
393
394        args_left_int = [[1, 8, 8], [2, 1, 1], [], [2], [1], [1, 2, 3]]
395        args_right_int = [[2, 1, 1], [1, 8, 8], [], [1], [], [1, 2]]
396        run_tests(jit_min_list, args_left_int, args_right_int)
397
398        args_left_float = [
399            [1.0, 8.0, 8.0],
400            [2.0, 1.0, 1.0],
401            [],
402            [2.0],
403            [1.0],
404            [1.0, 2.0, 3.0],
405        ]
406        args_right_float = [[2.0, 1.0, 1.0], [1.0, 8.0, 8.0], [], [1.0], [], [1.0, 2.0]]
407        run_tests(jit_min_list_float, args_left_float, args_right_float)
408
409        args_left_bool = [
410            [],
411            [],
412            [],
413            [False],
414            [True],
415            [False, True],
416            [True, True],
417            [False, False, False],
418            [False, False, True],
419        ]
420        args_right_bool = [
421            [],
422            [False],
423            [True],
424            [True],
425            [False],
426            [True, True],
427            [False, True],
428            [False, False, True],
429            [False, False, False],
430        ]
431        run_tests(jit_min_list_bool, args_left_bool, args_right_bool)
432
433        def jit_max_list(a: List[int], b: List[int]) -> List[int]:
434            return max(a, b)
435
436        def jit_max_list_float(a: List[float], b: List[float]) -> List[float]:
437            return max(a, b)
438
439        def jit_max_list_bool(a: List[bool], b: List[bool]) -> List[bool]:
440            return max(a, b)
441
442        args_left_int = [[1, 8, 8], [8, 1, 1], [], [1], [], [1, 2]]
443        args_right_int = [[8, 1, 1], [1, 8, 8], [], [2], [1], [1, 2, 3]]
444        run_tests(jit_max_list, args_left_int, args_right_int)
445
446        args_left_float = [[1.0, 8.0, 8.0], [8.0, 1.0, 1.0], [], [1.0], [], [1.0, 2.0]]
447        args_right_float = [
448            [8.0, 1.0, 1.0],
449            [1.0, 8.0, 8.0],
450            [],
451            [2.0],
452            [1.0],
453            [1.0, 2.0, 3.0],
454        ]
455        run_tests(jit_max_list_float, args_left_float, args_right_float)
456
457        run_tests(jit_max_list_bool, args_left_bool, args_right_bool)
458
459    def test_list_gather(self):
460        def index():
461            a = [1, 2, 3]
462            return a[1]
463
464        self.checkScript(index, ())
465
466        def negative_index():
467            a = [1, 2, 3]
468            return a[-1]
469
470        self.checkScript(negative_index, ())
471
472        def bad_index():
473            a = [1, 2, 3]
474            return a[4]
475
476        self.checkScriptRaisesRegex(bad_index, (), Exception, "list index out of range")
477
478        def bad_negative_index():
479            a = [1, 2, 3]
480            return a[-5]
481
482        self.checkScriptRaisesRegex(
483            bad_negative_index, (), Exception, "list index out of range"
484        )
485
486    def test_list_len(self):
487        def func():
488            a = [1, 2, 3]
489            return len(a) == 3
490
491        self.checkScript(func, ())
492
493        def func2():
494            a = []
495            return len(a) == 0
496
497        self.checkScript(func2, ())
498
499    @skipIfTorchDynamo(
500        "TorchDynamo fails to raise on this checkScriptRaisesRegex, because we trace it properly now"
501    )
502    def test_list_ops(self):
503        def test_equality():
504            a = [1, 2, 3]
505            b = [1, 2, 3]
506            return a == b
507
508        self.checkScript(test_equality, (), optimize=True)
509
510        def test_equality_str():
511            a = ["foo", "bar"]
512            b = ["foo", "bar"]
513            return a == b
514
515        self.checkScript(test_equality_str, (), optimize=True)
516
517        def test_inequality():
518            a = [1, 2, 3]
519            b = [1, 2, 3]
520            return a != b
521
522        self.checkScript(test_inequality, (), optimize=True)
523
524        def test_inequality_str():
525            a = ["foo", "bar"]
526            b = ["foo", "bar", "food"]
527            return a != b
528
529        self.checkScript(test_inequality_str, (), optimize=True)
530
531        def test_non_equality():
532            a = [1, 2, 3]
533            b = [3]
534            return a == b
535
536        self.checkScript(test_non_equality, (), optimize=True)
537
538        def test_non_inequality():
539            a = [1, 2, 3]
540            b = [3]
541            return a != b
542
543        self.checkScript(test_non_equality, (), optimize=True)
544
545        def test_list_equality_as_cond():
546            a = [1, 2, 3]
547            b = [3]
548            if a == b:
549                c = 1
550            else:
551                c = 2
552            return c
553
554        self.checkScript(test_list_equality_as_cond, (), optimize=True)
555
556        def test_list_add():
557            a = [1, 2, 3]
558            b = [2]
559            c = a + b
560            return c == [1, 2, 3, 2]
561
562        self.checkScript(test_list_add, (), optimize=True)
563
564        def test_list_add_empty():
565            a = [1, 2, 3]
566            b = torch.jit.annotate(List[int], [])
567            c = a + b
568            return c == [1, 2, 3]
569
570        self.checkScript(test_list_add_empty, (), optimize=True)
571
572        def test_tensor_list_equality():
573            t1 = torch.ones([1, 1])
574            t2 = torch.ones([1, 1])
575            x = [t1, t2]
576            y = [t2, t1]
577            return x == y
578
579        self.checkScript(test_tensor_list_equality, (), optimize=True)
580
581        def test_invalid_list_equality():
582            t1 = torch.ones([2, 2])
583            t2 = torch.ones([2, 2])
584            x = [t1, t2]
585            y = [t2, t1]
586            # will throw since the tensors have more than one element
587            return x == y
588
589        self.checkScriptRaisesRegex(
590            test_invalid_list_equality, (), RuntimeError, "Boolean value of Tensor"
591        )
592
593    def test_list_sort(self):
594        template = dedent(
595            """
596        def func():
597            li_1 = {list_create}
598            li_2 = {list_create}
599            li_3 = {list_create}
600            li_1.sort()
601            li_2.sort(reverse=True)
602            li_4 = sorted(li_3)
603            return li_1, li_2, li_3, li_4
604        """
605        )
606
607        lists = [
608            "[]",
609            "[1, 3, 2]",
610            "[True, False, True]",
611            "[1.2, .2, 3.2]",
612            "[torch.tensor(1.0), torch.tensor(0.2), torch.tensor(0.5)]",
613            "[torch.tensor(5), torch.tensor(-2), torch.tensor(4)]",
614        ]
615        for li in lists:
616            code = template.format(list_create=li)
617            scope = {}
618            exec(code, globals(), scope)
619            cu = torch.jit.CompilationUnit(code)
620            t1 = cu.func()
621            t2 = scope["func"]()
622            self.assertEqual(t1, t2)
623
624        def test_fail(x: List[Tensor]) -> List[Tensor]:
625            x.sort()
626            return x
627
628        self.checkScriptRaisesRegex(
629            test_fail,
630            (([torch.zeros([2]), torch.zeros([2])],)),
631            Exception,
632            "Boolean value of Tensor with more than one value",
633        )
634
635        @torch.jit.script
636        def test_mutation():
637            a = [1, 2, 3]
638            a.sort()
639            return a
640
641        test_mutation()
642        FileCheck().check("aten::sort").run(test_mutation.graph_for())
643
644        def test_sorted_copy():
645            a = [torch.tensor(2), torch.tensor(0), torch.tensor(1)]
646            b = sorted(a)
647            a[0] = torch.tensor(10)
648            return a, b
649
650        self.checkScript(test_sorted_copy, ())
651
652    def test_list_slice(self):
653        def test_regular_slice():
654            a = [0, 1, 2, 3, 4]
655            return a[2:3] == [2]
656
657        self.checkScript(test_regular_slice, ())
658
659        def test_open_ended_slice():
660            a = [0, 1, 2, 3, 4]
661            return a[2:] == [2, 3, 4]
662
663        self.checkScript(test_open_ended_slice, ())
664
665        def test_open_ended_slice2():
666            a = [0, 1, 2, 3, 4]
667            return a[:2] == [0, 1]
668
669        self.checkScript(test_open_ended_slice2, ())
670
671        def test_negative_slice():
672            a = [0, 1, 2, 3, 4]
673            return a[:-1] == [0, 1, 2, 3]
674
675        self.checkScript(test_negative_slice, ())
676
677        def test_negative_slice2():
678            a = [0, 1, 2, 3, 4]
679            return a[-3:-1] == [2, 3]
680
681        self.checkScript(test_negative_slice2, ())
682
683        def test_backward_slice():
684            a = [0, 1, 2, 3, 4]
685            return a[3:2] == torch.jit.annotate(List[int], [])
686
687        self.checkScript(test_backward_slice, ())
688
689        def test_over_slice():
690            a = [0, 1, 2, 3, 4]
691            return a[3:10] == [3, 4]
692
693        self.checkScript(test_backward_slice, ())
694
695    def test_slice_index(self):
696        a = torch.tensor(
697            [
698                [[1, 11], [2, 22]],
699                [[3, 33], [4, 44]],
700                [[5, 55], [6, 66]],
701            ]
702        )
703
704        def test_index_slice1(x):
705            x = x[:, :, [0, 1]]
706            return x
707
708        self.checkScript(test_index_slice1, (a,))
709
710        def test_index_slice2(x):
711            x = x[[2, 1, 0], :, :]
712            return x
713
714        self.checkScript(test_index_slice2, (a,))
715
716        def test_index_slice3(x):
717            x = x[[0, 1], :, [1]]
718            return x
719
720        self.checkScript(test_index_slice3, (a,))
721
722        def test_index_slice_empty_list(x):
723            empty_list: List[int] = []
724            x = x[empty_list, :, :]
725            return x
726
727        self.checkScript(test_index_slice_empty_list, (a,))
728
729        def test_index_slice_out_of_bounds_index(x):
730            x = x[[4], :, :]
731            return x
732
733        with self.assertRaisesRegexWithHighlight(
734            RuntimeError,
735            "index 4 is out of bounds for dimension 0 with size 3",
736            "x[[4], :, :]",
737        ):
738            self.checkScript(test_index_slice_out_of_bounds_index, (a,))
739
740    def test_mutable_list_append(self):
741        def test_append():
742            a = [0, 1]
743            a.append(2)
744            a.append(3)
745            return a == [0, 1, 2, 3]
746
747        self.checkScript(test_append, ())
748
749    def test_comprehensions_basic(self):
750        def comp(l: List[int]) -> List[int]:
751            n = [x * 3 for x in l]
752            return n
753
754        comp([1, 2, 3])
755        self.checkScript(comp, ([1, 2, 3],))
756
757    def test_comprehensions_basic_float(self):
758        def comp(l: List[float]) -> List[float]:
759            n = [x * 3 for x in l]
760            return n
761
762        self.checkScript(comp, ([1.0, 2.0, 3.0],))
763
764    def test_comprehensions_two_comps(self):
765        @torch.jit.script
766        def comp(l1: List[int], l2: List[int]) -> List[int]:
767            n = [x * 3 for x in l1]
768            n2 = [x + 2 for x in l2]
769            return n + n2
770
771        self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7])
772
773    def test_comprehension_out_type_not_in_type(self):
774        def list_cast() -> int:
775            li = [int(i) for i in [torch.tensor(0), torch.tensor(1), torch.tensor(2)]]
776            return li[0] + li[1] + li[2]
777
778        self.checkScript(list_cast, ())
779
780    def test_comprehension_iterable(self):
781        def test_func(fn, inputs):
782            self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs))
783
784        def foo(names: List[int], results: List[int]) -> List[Tuple[int, int]]:
785            return [(k + 5, v - 2) for k, v in zip(names, results)]
786
787        test_func(foo, ([1, 2, 4], [4, 7, 9]))
788        test_func(foo, ([5], [4, 7, 9]))
789
790        def fn(x: int) -> List[int]:
791            return [i for i in range(x)]  # noqa: C416
792
793        test_func(fn, (9,))
794        test_func(fn, (0,))
795        test_func(fn, (-1,))
796
797        def changes_type():
798            a = [float(i) for i in range(5)]
799            b = [float(i) for i in [1, 2, 3, 4]]
800            c = [(float(i), j) for i, j in enumerate([1, 2, 3, 8])]
801            return a, b, c
802
803        test_func(changes_type, ())
804
805        def test_zero_iter():
806            return [str(i) for i, j in zip("", "")]
807
808        test_func(test_zero_iter, ())
809
810    def test_mutable_list_append_2(self):
811        def test_append_2():
812            a = [0, 1]
813            a.append(2)
814            a = [1]
815            a.append(4)
816            return a == [1, 4]
817
818        self.checkScript(test_append_2, ())
819
820    def test_mutable_list_append_if(self):
821        def test_append_if():
822            a = [1]
823            if 1 == 1:
824                a.append(4)
825            return a == [1, 4]
826
827        self.checkScript(test_append_if, ())
828
829    def test_mutable_list_append_if_else(self):
830        def test_append_if_else():
831            a = [1]
832            if 1 == 2:
833                a.append(4)
834            else:
835                a.append(10)
836            return a == [1, 10]
837
838        self.checkScript(test_append_if_else, ())
839
840    def test_mutable_list_append_loop(self):
841        def test_append_loop():
842            a = torch.jit.annotate(List[int], [])
843            for i in range(5):
844                a.append(i)
845
846            return a == [0, 1, 2, 3, 4]
847
848        self.checkScript(test_append_loop, ())
849
850    def test_mutable_list_append_loop_if(self):
851        def test_append_loop_if():
852            a = torch.jit.annotate(List[int], [])
853            for i in range(5):
854                if i > 3:
855                    a.append(i)
856                else:
857                    a.append(0)
858
859            return a == [0, 0, 0, 0, 4]
860
861        self.checkScript(test_append_loop_if, ())
862
863    def test_mutable_list_nested_loop(self):
864        def test_nested_loop():
865            a = torch.jit.annotate(List[int], [])
866            for i in range(2):
867                for j in range(2):
868                    a.append(i + j)
869
870            return a == [0, 1, 1, 2]
871
872        self.checkScript(test_nested_loop, ())
873
874    def test_mutable_list_function_inline(self):
875        @torch.jit.script
876        def bar(y: List[int]) -> None:
877            y.append(4)
878
879        @torch.jit.script
880        def foo():
881            x = [1, 2, 3]
882            bar(x)
883            return x
884
885        self.assertEqual(foo(), [1, 2, 3, 4])
886
887    def test_mutable_list_reverse_empty(self):
888        def test_reverse_empty():
889            a = []
890            a.reverse()
891
892            return a == []
893
894        self.checkScript(test_reverse_empty, ())
895
896    def test_mutable_list_reverse(self):
897        def test_reverse():
898            a = [1, 2, 3, 4]
899            a.reverse()
900
901            return a == [4, 3, 2, 1]
902
903        self.checkScript(test_reverse, ())
904
905    def test_mutable_tensor_list_reverse(self):
906        def test_tensor_reverse():
907            a = [torch.tensor(1), torch.tensor(2)]
908            a.reverse()
909
910            return a == [torch.tensor(2), torch.tensor(1)]
911
912        self.checkScript(test_tensor_reverse, ())
913
914    def test_mutable_list_pop_empty(self):
915        @torch.jit.script
916        def test_pop_empty():
917            a = torch.jit.annotate(List[int], [])
918            return a.pop()
919
920        with self.assertRaisesRegexWithHighlight(
921            RuntimeError, "pop from empty list", "a.pop"
922        ):
923            test_pop_empty()
924
925    def test_mutable_list_pop(self):
926        def test_pop():
927            a = [1, 2, 3, 4]
928            b = a.pop()
929
930            return b == 4
931
932        self.checkScript(test_pop, ())
933
934    def test_mutable_list_pop2(self):
935        def test_pop2():
936            a = [1, 2, 3, 4]
937            b = a.pop()
938
939            return len(a) == 3
940
941        self.checkScript(test_pop2, ())
942
943    def test_mutable_list_pop_at(self):
944        def test_pop_at():
945            a = [1, 2, 3, 4]
946            b = a.pop(1)
947
948            return b == 2
949
950        self.checkScript(test_pop_at, ())
951
952    def test_mutable_list_pop_at2(self):
953        def test_pop_at2():
954            a = [1, 2, 3, 4]
955            b = a.pop(1)
956
957            return len(a) == 3
958
959        self.checkScript(test_pop_at2, ())
960
961    def test_mutable_list_pop_at_negative(self):
962        def test_pop_at_negative():
963            a = [1, 2, 3, 4]
964            b = a.pop(-2)
965
966            return b == 3
967
968        self.checkScript(test_pop_at_negative, ())
969
970    def test_mutable_list_pop_at_negative2(self):
971        def test_pop_at_negative2():
972            a = [1, 2, 3, 4]
973            b = a.pop(-2)
974
975            return len(a) == 3
976
977        self.checkScript(test_pop_at_negative2, ())
978
979    def test_mutable_list_pop_slice(self):
980        def test_pop_slice():
981            a = [1, 2, 3, 4]
982            b = [1, 2, 3, 4]
983
984            a.pop()
985            b = b[:-1]
986
987            return a == b
988
989        self.checkScript(test_pop_slice, ())
990
991    def test_mutable_list_clear_empty(self):
992        def test_clear_empty():
993            a = torch.jit.annotate(List[int], [])
994            a.clear()
995
996            return len(a) == 0
997
998        self.checkScript(test_clear_empty, ())
999
1000    def test_mutable_list_clear(self):
1001        def test_clear():
1002            a = [1, 2, 3, 4]
1003            a.clear()
1004
1005            return len(a) == 0
1006
1007        self.checkScript(test_clear, ())
1008
1009    def test_mutable_list_insert(self):
1010        def test_list_insert():
1011            a = [1, 2, 3, 4]
1012            a.insert(2, 5)
1013
1014            return a == [1, 2, 5, 3, 4]
1015
1016        self.checkScript(test_list_insert, ())
1017
1018    def test_mutable_list_insert_negative(self):
1019        def test_list_insert_negative():
1020            a = [1, 2, 3, 4]
1021            a.insert(-1, 5)
1022
1023            return a == [1, 2, 3, 5, 4]
1024
1025        self.checkScript(test_list_insert_negative, ())
1026
1027    def test_mutable_list_insert_neg_out_of_bounds(self):
1028        def test_list_insert_neg_out_of_bounds():
1029            a = [1, 2, 3, 4]
1030            a.insert(-10, 5)
1031
1032            return a == [5, 1, 2, 3, 4]
1033
1034        self.checkScript(test_list_insert_neg_out_of_bounds, ())
1035
1036    def test_mutable_list_insert_out_of_bounds(self):
1037        def test_list_insert_out_of_bounds():
1038            a = [1, 2, 3, 4]
1039            a.insert(10, 5)
1040
1041            return a == [1, 2, 3, 4, 5]
1042
1043        self.checkScript(test_list_insert_out_of_bounds, ())
1044
1045    def test_mutable_list_remove_not_existing(self):
1046        @torch.jit.script
1047        def test_list_remove_not_existing():
1048            a = [1, 2, 3, 4]
1049            a.remove(5)
1050
1051            return a
1052
1053        with self.assertRaisesRegexWithHighlight(
1054            RuntimeError, "x not in list", "a.remove"
1055        ):
1056            test_list_remove_not_existing()
1057
1058    def test_mutable_list_remove(self):
1059        def test_list_remove():
1060            a = [1, 2, 3, 4]
1061            a.remove(3)
1062
1063            return a == [1, 2, 4]
1064
1065        self.checkScript(test_list_remove, ())
1066
1067        def test_str_list_remove():
1068            a = ["foo", "bar"]
1069            a.remove("foo")
1070
1071            return a == ["bar"]
1072
1073        self.checkScript(test_str_list_remove, ())
1074
1075    def test_list_index_not_existing(self):
1076        @torch.jit.script
1077        def list_index_not_existing():
1078            a = [4, 1, 3, 2]
1079            i = a.index(5)
1080
1081            return i
1082
1083        with self.assertRaisesRegexWithHighlight(
1084            RuntimeError, "'5' is not in list", "a.index"
1085        ):
1086            list_index_not_existing()
1087
1088    def test_list_index(self):
1089        def list_index():
1090            a = [4, 1, 3, 2]
1091            i = a.index(3)
1092
1093            return i == 2
1094
1095        self.checkScript(list_index, ())
1096
1097        def list_str_index():
1098            a = ["foo", "bar"]
1099            i = a.index("bar")
1100
1101            return i == 1
1102
1103        self.checkScript(list_str_index, ())
1104
1105    def test_tensor_list_index(self):
1106        def tensor_list_index():
1107            a = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)]
1108            i = a.index(torch.tensor(3))
1109
1110            return i == 2
1111
1112        self.checkScript(tensor_list_index, ())
1113
1114    def test_tensor_list_index_not_existing(self):
1115        @torch.jit.script
1116        def tensor_list_index_not_existing():
1117            a = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)]
1118            i = a.index(torch.tensor(5))
1119
1120            return i
1121
1122        with self.assertRaisesRegexWithHighlight(
1123            RuntimeError, "is not in list", "a.index"
1124        ):
1125            tensor_list_index_not_existing()
1126
1127    def test_list_count(self):
1128        def list_count():
1129            a = [4, 1, 4, 2, 4]
1130            i = a.count(4)
1131
1132            return i == 3
1133
1134        self.checkScript(list_count, ())
1135
1136        def list_str_count():
1137            a = ["foo", "bar", "foo"]
1138            i = a.count("foo")
1139
1140            return i == 2
1141
1142        self.checkScript(list_str_count, ())
1143
1144    def test_list_count_not_existing(self):
1145        def list_count_not_existing():
1146            a = [4, 1, 4, 2, 4]
1147            i = a.count(5)
1148
1149            return i == 0
1150
1151        self.checkScript(list_count_not_existing, ())
1152
1153    def test_tensor_list_count(self):
1154        def tensor_list_count():
1155            a = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)]
1156            i = a.count(torch.tensor(4))
1157
1158            return i == 3
1159
1160        self.checkScript(tensor_list_count, ())
1161
1162    def test_tensor_list_count_not_existing(self):
1163        def tensor_list_count_not_existing():
1164            a = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)]
1165            i = a.count(torch.tensor(5))
1166
1167            return i == 0
1168
1169        self.checkScript(tensor_list_count_not_existing, ())
1170
1171    def test_mutable_list_remove_tensor(self):
1172        def test_list_remove_tensor():
1173            a = [torch.ones(1), torch.zeros(1), torch.ones(2)]
1174            a.remove(torch.zeros(1))
1175
1176            return len(a) == 2
1177
1178        self.checkScript(test_list_remove_tensor, ())
1179
1180    def test_mutable_list_remove2(self):
1181        def test_list_remove2():
1182            a = [1]
1183            a.remove(1)
1184
1185            return len(a) == 0
1186
1187        self.checkScript(test_list_remove2, ())
1188
1189    def test_extend_list_mutable(self):
1190        @torch.jit.script
1191        def extend_list(a: List[Tensor], b: List[Tensor]) -> List[Tensor]:
1192            a.extend(b)
1193            return a
1194
1195        for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
1196            for r in [
1197                [],
1198                [torch.rand(2)],
1199                [torch.rand(2), torch.rand(2), torch.rand(2)],
1200            ]:
1201                self.assertEqual(extend_list(l, r), l + r)
1202
1203    def test_extend_list_immutable(self):
1204        @torch.jit.script
1205        def extend_list(a: List[int], b: List[int]) -> List[int]:
1206            a.extend(b)
1207            return a
1208
1209        for l in [[], [1], [1, 2, 3]]:
1210            for r in [[], [1], [1, 2, 3]]:
1211                self.assertEqual(extend_list(l, r), l + r)
1212
1213    def test_copy_list_mutable(self):
1214        @torch.jit.script
1215        def copy_list(a: List[Tensor]) -> List[Tensor]:
1216            return a.copy()
1217
1218        for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
1219            self.assertEqual(copy_list(l), l)
1220
1221    def test_copy_list_immutable(self):
1222        @torch.jit.script
1223        def copy_list(a: List[int]) -> List[int]:
1224            return a.copy()
1225
1226        for l in [[], [1], [1, 2, 3]]:
1227            self.assertEqual(copy_list(l), l)
1228
1229    def test_min_max_single_list(self):
1230        def min_intlist(li: List[int]) -> int:
1231            return min(li)
1232
1233        def max_intlist(li: List[int]) -> int:
1234            return max(li)
1235
1236        def min_boollist(li: List[bool]) -> bool:
1237            return min(li)
1238
1239        def max_boollist(li: List[bool]) -> bool:
1240            return max(li)
1241
1242        def min_floatlist(li: List[float]) -> float:
1243            return min(li)
1244
1245        def max_floatlist(li: List[float]) -> float:
1246            return max(li)
1247
1248        int_lists = [1], [2, 1, 2], [-3, 4, 2], [-2, -7, 1, 4], [2, 1, 0, 4], []
1249
1250        def check_list(fn, li):
1251            if len(li) == 0:
1252                self.checkScriptRaisesRegex(fn, (li,), Exception, "empty")
1253            else:
1254                self.checkScript(fn, (li,))
1255
1256        for int_list in int_lists:
1257            check_list(min_intlist, int_list)
1258            check_list(max_intlist, int_list)
1259
1260            bool_li = [bool(x) for x in int_list]
1261            check_list(min_boollist, bool_li)
1262            check_list(max_boollist, bool_li)
1263
1264            float_li = [float(x) for x in int_list]
1265            check_list(min_floatlist, float_li)
1266            check_list(max_floatlist, float_li)
1267
1268    def test_to_list(self):
1269        """Unit tests for Tensor.tolist() function."""
1270
1271        """
1272        Boolean dtype unit tests.
1273        """
1274
1275        def to_list_bool_0D(x: torch.Tensor) -> bool:
1276            li = torch.jit.annotate(bool, x.tolist())
1277            return li
1278
1279        def to_list_bool_1D(x: torch.Tensor) -> List[bool]:
1280            li = torch.jit.annotate(List[bool], x.tolist())
1281            return li
1282
1283        def to_list_bool_2D(x: torch.Tensor) -> List[List[bool]]:
1284            li = torch.jit.annotate(List[List[bool]], x.tolist())
1285            return li
1286
1287        def to_list_bool_3D(x: torch.Tensor) -> List[List[List[bool]]]:
1288            li = torch.jit.annotate(List[List[List[bool]]], x.tolist())
1289            return li
1290
1291        self.checkScript(to_list_bool_0D, (torch.tensor(False, dtype=torch.bool),))
1292        bool_input_1D = torch.tensor([True, False, True, False], dtype=torch.bool)
1293        self.checkScript(to_list_bool_1D, (bool_input_1D,))
1294        bool_input_2D = torch.tensor(
1295            [[True, True, False], [False, True, False]], dtype=torch.bool
1296        )
1297        self.checkScript(to_list_bool_2D, (bool_input_2D,))
1298        bool_input_3D = torch.tensor(
1299            [[[True, False], [False, True]], [[True, False], [False, False]]],
1300            dtype=torch.bool,
1301        )
1302        self.checkScript(to_list_bool_3D, (bool_input_3D,))
1303        bool_input_noncontiguous = torch.tensor(
1304            [[[True, False], [False, True]], [[True, False], [False, False]]],
1305            dtype=torch.bool,
1306        ).transpose(0, 1)
1307        self.checkScript(to_list_bool_3D, (bool_input_noncontiguous,))
1308
1309        """
1310        Int dtype unit tests.
1311        """
1312
1313        def to_list_int_0D(x: torch.Tensor) -> int:
1314            li = torch.jit.annotate(int, x.tolist())
1315            return li
1316
1317        def to_list_int_1D(x: torch.Tensor) -> List[int]:
1318            li = torch.jit.annotate(List[int], x.tolist())
1319            return li
1320
1321        def to_list_int_2D(x: torch.Tensor) -> List[List[int]]:
1322            li = torch.jit.annotate(List[List[int]], x.tolist())
1323            return li
1324
1325        def to_list_int_3D(x: torch.Tensor) -> List[List[List[int]]]:
1326            li = torch.jit.annotate(List[List[List[int]]], x.tolist())
1327            return li
1328
1329        self.checkScript(to_list_int_0D, (torch.tensor(1, dtype=torch.long),))
1330        int_input_1D = torch.tensor([1, 2, 3, 4], dtype=torch.long)
1331        self.checkScript(to_list_int_1D, (int_input_1D,))
1332        int_input_2D = torch.tensor([[1, 2, 3], [3, 4, 5]], dtype=torch.long)
1333        self.checkScript(to_list_int_2D, (int_input_2D,))
1334        int_input_3D = torch.tensor(
1335            [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long
1336        )
1337        self.checkScript(to_list_int_3D, (int_input_3D,))
1338        int_input_noncontiguous = torch.tensor(
1339            [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.long
1340        ).transpose(0, 1)
1341        self.checkScript(to_list_int_3D, (int_input_noncontiguous,))
1342
1343        """
1344        Float dtype unit tests.
1345        """
1346
1347        def to_list_float_0D(x: torch.Tensor) -> float:
1348            li = torch.jit.annotate(float, x.tolist())
1349            return li
1350
1351        def to_list_float_1D(x: torch.Tensor) -> List[float]:
1352            li = torch.jit.annotate(List[float], x.tolist())
1353            return li
1354
1355        def to_list_float_2D(x: torch.Tensor) -> List[List[float]]:
1356            li = torch.jit.annotate(List[List[float]], x.tolist())
1357            return li
1358
1359        def to_list_float_3D(x: torch.Tensor) -> List[List[List[float]]]:
1360            li = torch.jit.annotate(List[List[List[float]]], x.tolist())
1361            return li
1362
1363        # Test with torch.float dtype Tensors to check that they are converted to double automatically.
1364        self.checkScript(to_list_float_0D, (torch.randn(5, dtype=torch.float)[0],))
1365        self.checkScript(to_list_float_1D, (torch.randn(5, dtype=torch.float),))
1366        self.checkScript(to_list_float_2D, (torch.randn(5, 6, dtype=torch.float),))
1367        self.checkScript(to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.float),))
1368        self.checkScript(
1369            to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.float).transpose(0, 1),)
1370        )
1371
1372        self.checkScript(to_list_float_0D, (torch.randn(5, dtype=torch.double)[0],))
1373        self.checkScript(to_list_float_1D, (torch.randn(5, dtype=torch.double),))
1374        self.checkScript(to_list_float_2D, (torch.randn(5, 6, dtype=torch.double),))
1375        self.checkScript(to_list_float_3D, (torch.randn(5, 6, 7, dtype=torch.double),))
1376        self.checkScript(
1377            to_list_float_3D,
1378            (torch.randn(5, 6, 7, dtype=torch.double).transpose(0, 1),),
1379        )
1380
1381        """
1382        Complex dtype unit tests.
1383        """
1384
1385        def to_list_complex_0D(x: torch.Tensor) -> complex:
1386            li = torch.jit.annotate(complex, x.tolist())
1387            return li
1388
1389        def to_list_complex_1D(x: torch.Tensor) -> List[complex]:
1390            li = torch.jit.annotate(List[complex], x.tolist())
1391            return li
1392
1393        def to_list_complex_2D(x: torch.Tensor) -> List[List[complex]]:
1394            li = torch.jit.annotate(List[List[complex]], x.tolist())
1395            return li
1396
1397        def to_list_complex_3D(x: torch.Tensor) -> List[List[List[complex]]]:
1398            li = torch.jit.annotate(List[List[List[complex]]], x.tolist())
1399            return li
1400
1401        # Test with torch.complex dtype Tensors to check that they are converted to double automatically.
1402        self.checkScript(to_list_complex_0D, (torch.randn(5, dtype=torch.cfloat)[0],))
1403        self.checkScript(to_list_complex_1D, (torch.randn(5, dtype=torch.cfloat),))
1404        self.checkScript(to_list_complex_2D, (torch.randn(5, 6, dtype=torch.cfloat),))
1405        self.checkScript(
1406            to_list_complex_3D, (torch.randn(5, 6, 7, dtype=torch.cfloat),)
1407        )
1408        self.checkScript(
1409            to_list_complex_3D,
1410            (torch.randn(5, 6, 7, dtype=torch.cfloat).transpose(0, 1),),
1411        )
1412
1413        self.checkScript(to_list_complex_0D, (torch.randn(5, dtype=torch.cdouble)[0],))
1414        self.checkScript(to_list_complex_1D, (torch.randn(5, dtype=torch.cdouble),))
1415        self.checkScript(to_list_complex_2D, (torch.randn(5, 6, dtype=torch.cdouble),))
1416        self.checkScript(
1417            to_list_complex_3D, (torch.randn(5, 6, 7, dtype=torch.cdouble),)
1418        )
1419        self.checkScript(
1420            to_list_complex_3D,
1421            (torch.randn(5, 6, 7, dtype=torch.cdouble).transpose(0, 1),),
1422        )
1423
1424        """
1425        Non-happy path tests:
1426            - missing type annotation
1427            - mismatch between type annotation and input
1428            - type annotation with unsupported type
1429            - type annotation with the wrong dimension
1430            - type annotation with scalar type that doesn't match the input scalar type
1431        """
1432
1433        def to_list_missing_type_annotation(x: torch.Tensor) -> List[float]:
1434            li = x.tolist()
1435            return li
1436
1437        def to_list_incorrect_type_annotation(x: torch.Tensor) -> List[float]:
1438            li = torch.jit.annotate(float, x.tolist())
1439            return li
1440
1441        def to_list_unsupported_type_annotation(x: torch.Tensor) -> List[float]:
1442            li = torch.jit.annotate(List[str], x.tolist())
1443            return li
1444
1445        def to_list_type_annotation_wrong_dim(x: torch.Tensor) -> List[List[float]]:
1446            li = torch.jit.annotate(List[List[float]], x.tolist())
1447            return li
1448
1449        def to_list_type_annotation_incorrect_scalar_type(
1450            x: torch.Tensor,
1451        ) -> List[float]:
1452            li = torch.jit.annotate(List[float], x.tolist())
1453            return li
1454
1455        with self.assertRaisesRegexWithHighlight(
1456            RuntimeError, r"Expected type hint for result of tolist()", "x.tolist("
1457        ):
1458            self.checkScript(to_list_missing_type_annotation, (torch.randn(5),))
1459
1460        with self.assertRaisesRegexWithHighlight(
1461            RuntimeError,
1462            r"Return value was annotated as having type List\[float\] but is actually of type float",
1463            "return li",
1464        ):
1465            self.checkScript(to_list_incorrect_type_annotation, (torch.randn(5),))
1466
1467        with self.assertRaisesRegex(
1468            RuntimeError, r"str is not one of the supported element types for tolist"
1469        ):
1470            self.checkScript(to_list_unsupported_type_annotation, (torch.randn(5),))
1471
1472        with self.assertRaisesRegex(
1473            RuntimeError,
1474            r"Output annotation list dimension and runtime tensor dimension must match",
1475        ):
1476            self.checkScript(
1477                to_list_type_annotation_wrong_dim, (torch.randn(5, dtype=torch.double),)
1478            )
1479
1480        with self.assertRaisesRegex(
1481            RuntimeError,
1482            r"Output annotation element type and runtime tensor element type must match",
1483        ):
1484            self.checkScript(
1485                to_list_type_annotation_incorrect_scalar_type,
1486                (torch.ones(5, dtype=torch.long),),
1487            )
1488
1489    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
1490    def test_to_list_gpu(self):
1491        """GPU tests for Tensor.tolist() function."""
1492
1493        def to_list_bool_1D(x: torch.Tensor) -> List[bool]:
1494            li = torch.jit.annotate(List[bool], x.tolist())
1495            return li
1496
1497        def to_list_int_1D(x: torch.Tensor) -> List[int]:
1498            li = torch.jit.annotate(List[int], x.tolist())
1499            return li
1500
1501        def to_list_float_1D(x: torch.Tensor) -> List[float]:
1502            li = torch.jit.annotate(List[float], x.tolist())
1503            return li
1504
1505        self.checkScript(
1506            to_list_bool_1D,
1507            (torch.tensor([True, False, True, False], dtype=torch.bool).cuda(),),
1508        )
1509        self.checkScript(
1510            to_list_int_1D, (torch.tensor([1, 2, 3, 4], dtype=torch.long).cuda(),)
1511        )
1512        self.checkScript(to_list_float_1D, (torch.randn(5, dtype=torch.double).cuda(),))
1513
1514    def test_no_element_type_annotation(self):
1515        def fn_with_comment(x: torch.Tensor) -> List:
1516            a: List = x.tolist()
1517            return a
1518
1519        def annotated_fn(x: torch.Tensor) -> List:
1520            a: List = x.tolist()
1521            return a
1522
1523        with self.assertRaisesRegex(
1524            RuntimeError, r"Attempted to use List without a contained type"
1525        ):
1526            cu = torch.jit.CompilationUnit()
1527            cu.define(dedent(inspect.getsource(fn_with_comment)))
1528
1529        with self.assertRaisesRegex(
1530            RuntimeError, r"Attempted to use List without a contained type"
1531        ):
1532            cu = torch.jit.CompilationUnit()
1533            cu.define(dedent(inspect.getsource(annotated_fn)))
1534
1535        with self.assertRaisesRegex(
1536            RuntimeError, r"Attempted to use List without a contained type"
1537        ):
1538            torch.jit.script(fn_with_comment)
1539
1540        with self.assertRaisesRegex(
1541            RuntimeError, r"Attempted to use List without a contained type"
1542        ):
1543            torch.jit.script(annotated_fn)
1544
1545    def test_list_none(self):
1546        with self.assertRaisesRegex(
1547            RuntimeError, "Can not create ListType with None type"
1548        ):
1549            x = torch._C.ListType(None)
1550
1551    def test_list_unification_hint(self):
1552        with self.assertRaisesRegex(
1553            RuntimeError, "Expected an annotation of type List"
1554        ):
1555
1556            @torch.jit.script
1557            def x():
1558                b: int = [2, 3]
1559                return b
1560
1561
1562class TestDict(JitTestCase):
1563    def dict(self):
1564        return {"a": torch.ones(1), "b": torch.ones(1) + 1, "c": torch.ones(1) + 2}
1565
1566    def dict2(self):
1567        return {
1568            "x": torch.ones(1) + 100,
1569            "y": torch.ones(1) + 101,
1570            "z": torch.ones(1) + 102,
1571        }
1572
1573    def dict_bool(self):
1574        return {True: 1}
1575
1576    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1577    def test_dict_bool_conversion(self):
1578        def if_predicate(d: Dict[int, int]):
1579            if d:
1580                s, t = 0, 0
1581                for k, v in d.items():
1582                    s += k
1583                    t += v
1584
1585                return s, t
1586            else:
1587                return -1, -1
1588
1589        self.checkScript(if_predicate, ({1: 2, 3: 5},))
1590        self.checkScript(if_predicate, ({},))
1591
1592        def while_predicate(d: Dict[int, int]):
1593            while d:
1594                d.clear()
1595
1596        self.checkScript(while_predicate, ({1: 2, 3: 5},))
1597        self.checkScript(while_predicate, ({},))
1598
1599        def ternary_predicate(d: Dict[int, int]):
1600            return "non-empty" if d else "empty"
1601
1602        self.checkScript(ternary_predicate, ({1: 2, 3: 5},))
1603        self.checkScript(ternary_predicate, ({},))
1604
1605    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1606    def test_del(self):
1607        def inputs():
1608            return {"hi": 2, "bye": 3}
1609
1610        def fn(x: Dict[str, int]) -> Dict[str, int]:
1611            del x["hi"]
1612            return x
1613
1614        python_out = fn(inputs())
1615        # checkScript reuses the same object, but here it's being mutated so do
1616        # it manually
1617        cu = torch.jit.CompilationUnit()
1618        cu.define(dedent(inspect.getsource(fn)))
1619        self.assertEqual(cu.fn(inputs()), python_out)
1620        self.assertEqual(torch.jit.script(fn)(inputs()), python_out)
1621        with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", 'x["hi"]'):
1622            self.checkScript(fn, [{}])
1623
1624    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1625    def test_dict_variance(self):
1626        """
1627        `Dict[T1, _]` is not a subtype of `Dict[T2, _]`, even if `T1` is
1628        a subtype of `T2`; similarly `Dict[_, T1]` would not be a
1629        subtype of `Dict[_, T2]`.
1630
1631        However, if we have a temporary dict object (that is, a dict
1632        comprehension or a dict literal) on the rhs of an assignment
1633        statement, we want to ignore the inferred type of the rhs if we
1634        can prove that: 1) both the lhs and the rhs are dicts with the
1635        same key types (TorchScript has a restricted set of allowed key
1636        types, so we don't need to worry about subtyping relationships
1637        here), and 2) the value type of the dict is a subtype of the
1638        value type of the rhs dict.
1639        """
1640
1641        def test_dictliteral_is_typed_from_annotation():
1642            x: Dict[str, Optional[int]] = {"foo": None, "bar": None, "baz": None}
1643            return x
1644
1645        self.checkScript(test_dictliteral_is_typed_from_annotation, ())
1646
1647        def test_dictcomprehension_is_typed_from_annotation():
1648            metasyntactics = ["foo", "bar", "baz"]
1649            x: Dict[str, Optional[int]] = {  # noqa: C420, RUF025
1650                word: None for word in metasyntactics
1651            }
1652            return x
1653
1654        self.checkScript(test_dictcomprehension_is_typed_from_annotation, ())
1655
1656        def test_dicts_with_different_value_types_are_invariant(self):
1657            x: Dict[str, int] = {"foo": 1, "bar": 2, "baz": 3}
1658            y: Dict[str, Optional[int]] = x
1659            return x
1660
1661        with self.assertRaisesRegex(
1662            RuntimeError,
1663            "Variable 'y' is "
1664            "annotated with type "
1665            r"Dict\[str, Optional\[int\]\] but "
1666            "is being assigned to a value of "
1667            r"type Dict\[str, int\]",
1668        ):
1669            torch.jit.script(test_dicts_with_different_value_types_are_invariant)
1670
1671        def test_dicts_with_different_value_types_are_invariant_recursive(self):
1672            x: Dict[str, int] = {"foo": 1, "bar": 2, "baz": 3}
1673            y: Dict[str, Dict[str, int]] = {"foo": x, "bar": x, "baz": x}
1674            z: Dict[str, Dict[str, Optional[int]]] = y
1675            return x
1676
1677        with self.assertRaisesRegex(
1678            RuntimeError,
1679            "Variable 'z' is "
1680            "annotated with type "
1681            r"Dict\[str, Dict\[str, Optional"
1682            r"\[int\]\]\] but is being assigned"
1683            r" to a value of type Dict\[str, "
1684            r"Dict\[str, int\]\]",
1685        ):
1686            torch.jit.script(
1687                test_dicts_with_different_value_types_are_invariant_recursive
1688            )
1689
1690    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1691    def test_keys(self):
1692        @torch.jit.script
1693        def keys(x: Dict[str, Tensor]) -> List[str]:
1694            return list(x.keys())
1695
1696        self.assertEqual(set(keys(self.dict())), set(self.dict().keys()))
1697
1698        @torch.jit.script
1699        def specialized_list():
1700            li = {1: 1, 2: 2}.keys()
1701            li.append(3)
1702            return li
1703
1704        self.assertTrue(set(specialized_list()) == {1, 2, 3})
1705
1706    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1707    def test_values(self):
1708        @torch.jit.script
1709        def values(x: Dict[str, Tensor]) -> List[Tensor]:
1710            return list(x.values())
1711
1712        the_dict = self.dict()
1713        self.assertEqual(set(values(the_dict)), set(the_dict.values()))
1714
1715    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1716    def test_len(self):
1717        def length(x: Dict[str, Tensor]) -> int:
1718            return len(x)
1719
1720        self.checkScript(length, (self.dict(),))
1721
1722    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1723    def test_copy(self):
1724        def func(x: Dict[str, Tensor]) -> Dict[str, Tensor]:
1725            return x.copy()
1726
1727        self.checkScript(func, (self.dict(),))
1728
1729    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1730    def test_items(self):
1731        def func(x: Dict[str, Tensor]) -> List[Tuple[str, Tensor]]:
1732            return x.items()
1733
1734        # The value returned by Python is in arbitrary order, so we can't use
1735        # checkScript
1736        scripted_func = torch.jit.script(func)
1737
1738        eager_out = func(self.dict())
1739        script_out = scripted_func(self.dict())
1740
1741        self.assertEqual(len(eager_out), len(script_out))
1742        for item in eager_out:
1743            self.assertTrue(item in script_out)
1744
1745    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1746    def test_pop(self):
1747        def pop(x: Dict[str, Tensor], key: str) -> Tuple[Tensor, Dict[str, Tensor]]:
1748            return x.pop(key), x
1749
1750        # checkScript doesn't copy the inputs, so we can't use it since this mutates
1751        # the dict
1752        def tester(fn, *args):
1753            eager_out = fn(self.dict(), *args)
1754            script_out = torch.jit.script(fn)(self.dict(), *args)
1755            self.assertEqual(eager_out, script_out)
1756
1757        tester(pop, "a")
1758
1759        with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", "x.pop"):
1760            torch.jit.script(pop)(self.dict(), "x")
1761
1762        def default_pop(
1763            x: Dict[str, Tensor], key: str, default: Tensor
1764        ) -> Tuple[Tensor, Dict[str, Tensor]]:
1765            return x.pop(key, default), x
1766
1767        tester(default_pop, "a", torch.randn(2, 2))
1768        tester(default_pop, "x", torch.randn(2, 2))
1769
1770    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1771    def test_setdefault(self):
1772        def setdefault(
1773            x: Dict[str, Tensor], key: str, default: Tensor
1774        ) -> Dict[str, Tensor]:
1775            x.setdefault(key, default)
1776            return x
1777
1778        self.checkScript(setdefault, (self.dict(), "a", torch.randn(2, 2)))
1779        self.checkScript(setdefault, (self.dict(), "nonexistant", torch.randn(2, 2)))
1780
1781    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1782    def test_update(self):
1783        def update(
1784            a: Dict[str, Tensor], b: Dict[str, Tensor]
1785        ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
1786            a.update(b)
1787            return a, b
1788
1789        self.checkScript(update, (self.dict(), self.dict()))
1790        self.checkScript(update, (self.dict(), self.dict2()))
1791
1792    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1793    def test_update_existing_key(self):
1794        def foo() -> Dict[str, int]:
1795            a: Dict[str, int] = {}
1796            for i in range(3):
1797                a.update({"a": i})
1798            return a
1799
1800        self.checkScript(foo, ())
1801
1802    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1803    def test_aug_assign(self):
1804        def aug_assign_dict_tensor(a: Dict[str, Tensor]) -> Dict[str, Tensor]:
1805            a["a"] += 1
1806            a["b"] -= 12
1807            a["c"] *= 122
1808            a["c"] /= 2
1809            a["c"] %= 2
1810            return a
1811
1812        def aug_assign_dict_prim(a: Dict[str, float]) -> Dict[str, float]:
1813            a["a"] += 3.4
1814            a["b"] -= 2.4
1815            a["c"] *= 3.0
1816            a["c"] /= 2.0
1817            a["c"] %= 2.0
1818            return a
1819
1820        self.checkScript(aug_assign_dict_tensor, (self.dict(),))
1821        self.checkScript(aug_assign_dict_prim, ({"a": 3.0, "b": 2.0, "c": 4.0},))
1822
1823    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1824    def test_popitem(self):
1825        @torch.jit.script
1826        def popitem(
1827            x: Dict[str, Tensor]
1828        ) -> Tuple[Tuple[str, Tensor], Dict[str, Tensor]]:
1829            item = x.popitem()
1830            return item, x
1831
1832        # The value returned by Python is arbitrary, so we can't use checkScript
1833        eager_in = self.dict()
1834        eager_out = (eager_in.popitem(), eager_in)
1835
1836        script_out = popitem(self.dict())
1837
1838        # Check that an item was removed
1839        self.assertEqual(len(eager_out[1]), len(script_out[1]))
1840
1841        # Check that the item is the correct types
1842        self.assertTrue(isinstance(script_out[0][0], str))
1843        self.assertTrue(isinstance(script_out[0][1], torch.Tensor))
1844
1845    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1846    def test_clear(self):
1847        def clear(x: Dict[str, Tensor]) -> Dict[str, Tensor]:
1848            x.clear()
1849            return x
1850
1851        self.checkScript(clear, (self.dict(),))
1852
1853    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1854    def test_get(self):
1855        def get(x: Dict[str, Tensor], key: str) -> Optional[Tensor]:
1856            return x.get(key)
1857
1858        self.checkScript(get, (self.dict(), "a"))
1859        self.checkScript(get, (self.dict(), "doesn't exist"))
1860
1861        def get_default(x: Dict[str, Tensor], key: str) -> Optional[Tensor]:
1862            return x.get(key, torch.randn(2, 2))
1863
1864        self.checkScript(get, (self.dict(), "a"))
1865        self.checkScript(get, (self.dict(), "doesn't exist"))
1866
    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
    def test_get_boolkey(self):
        """Check ``dict.get`` on a dict with ``bool`` keys, with and without a
        default, looking up both ``True`` and ``False``."""

        # self.dict_bool() is a fixture defined elsewhere in this class.
        def get(x: Dict[bool, int], key: bool) -> Optional[int]:
            return x.get(key)

        self.checkScript(get, (self.dict_bool(), True))
        self.checkScript(get, (self.dict_bool(), False))

        # With a default the result is a plain int (42 for a missing key).
        def get_default(x: Dict[bool, int], key: bool) -> int:
            return x.get(key, 42)

        self.checkScript(get_default, (self.dict_bool(), True))
        self.checkScript(get_default, (self.dict_bool(), False))
1880
    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
    def test_basic(self):
        """Basic Dict support: pass-through, indexing, an empty literal with a
        declared return type, missing-key errors, annotated literals compiled
        from source strings, and lists of dicts."""

        # Identity: a Dict[str, int] argument is returned unchanged.
        def simple(x: Dict[str, int]) -> Dict[str, int]:
            return x

        self.checkScript(simple, ({"item": 20, "other_item": 120},))

        # Indexing an existing key.
        def index(x: Dict[str, int]) -> int:
            return x["item"]

        self.checkScript(index, ({"item": 20, "other_item": 120},))

        # An empty literal returned where Dict[str, Tensor] is declared.
        def type_default() -> Dict[str, Tensor]:
            return {}

        self.checkScript(type_default, ())

        # Indexing a missing key fails at runtime; the error must mention
        # KeyError and highlight the failing subscript.
        @torch.jit.script
        def missing_index(x: Dict[str, int]) -> int:
            return x["dne"]

        with self.assertRaisesRegexWithHighlight(RuntimeError, "KeyError", 'x["dne"'):
            missing_index({"item": 20, "other_item": 120})

        # torch.jit.annotate on empty and non-empty dict literals, compiled
        # from source strings through CompilationUnit.
        code = dedent(
            """
            def literal1():
                return torch.jit.annotate(Dict[int, float], {})
            def literal2():
                return torch.jit.annotate(Dict[int, float], {10: 1.2})
        """
        )
        cu = torch.jit.CompilationUnit(code)
        self.assertEqual({}, cu.literal1())
        self.assertEqual({10: 1.2}, cu.literal2())

        # Same as above with a multi-entry literal.
        cu = torch.jit.CompilationUnit(
            dedent(
                """
            def literal3():
                return torch.jit.annotate(Dict[int, float], {10: 1.2, 11: 1.3})
        """
            )
        )
        self.assertEqual({10: 1.2, 11: 1.3}, cu.literal3())

        # A list literal whose elements are dict literals.
        def list_of_dicts() -> List[Dict[str, Tensor]]:
            return [{"word": torch.ones(2) + 3}, {"other word": torch.ones(1) + 2}]

        self.checkScript(list_of_dicts, ())
1931
    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
    def test_mutability(self):
        """An empty annotated dict created in script can be filled by item
        assignment and returned."""

        @torch.jit.script
        def fn() -> Dict[str, int]:
            a = torch.jit.annotate(Dict[str, int], {})
            a["ok"] = 10
            return a

        self.assertEqual(fn(), {"ok": 10})
1941
    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
    def test_key_type(self):
        """Indexing a ``Dict[str, int]`` with ``None`` is rejected while
        scripting (``fn`` is never called), and the error highlights
        ``a[None]``."""
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "but instead found type", "a[None]"
        ):

            @torch.jit.script
            def fn(a: Dict[str, int]) -> int:
                return a[None]
1951
    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
    def test_loop(self):
        """Item assignment inside a loop: the same key is overwritten on every
        iteration, so the final value is the last loop index (9 for x=10)."""

        @torch.jit.script
        def fn(x: int) -> Dict[str, int]:
            a = torch.jit.annotate(Dict[str, int], {})
            for i in range(x):
                a["ok"] = i
            return a

        self.assertEqual(fn(10), {"ok": 9})
1962
    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
    def test_view(self):
        """A tensor read back out of a dict aliases the stored tensor, so an
        in-place update through it is visible through ``x``."""

        def fn(x, y):
            l = {"a": x}
            x_view = l["a"]
            a = x + x
            x_view.add_(y)
            b = x + x
            return a == b

        # In eager Python x_view is x itself, so b reflects the add_ and
        # differs from a wherever y is nonzero. checkScript requires the
        # scripted function to produce the same elementwise result.
        self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
1974
1975    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
1976    def test_membership(self):
1977        def fn(x: Dict[int, int], y: int) -> int:
1978            return x.get(y, 3)
1979
1980        d = {1: 2, 3: 4}
1981        self.checkScript(fn, (d, 3))
1982        self.checkScript(fn, (d, 2))
1983
1984        def optional(x: Dict[int, int], y: int) -> bool:
1985            res = x.get(y)
1986            return res is None
1987
1988        self.checkScript(fn, (d, 3))
1989        self.checkScript(fn, (d, 2))
1990
1991        with self.assertRaisesRegexWithHighlight(
1992            RuntimeError, "is actually of type Optional", "return x.get(y"
1993        ):
1994
1995            @torch.jit.script
1996            def bad_types(x: Dict[int, int], y: int) -> int:
1997                return x.get(y)  # noqa: T484
1998
1999    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
2000    def test_dict_to_python(self):
2001        @torch.jit.ignore
2002        def python_lookup(my_dict: Dict[str, int], keys: List[str]) -> List[int]:
2003            return [my_dict[k] for k in keys]
2004
2005        def fn(my_dict: Dict[str, int], keys: List[str]) -> List[int]:
2006            return python_lookup(my_dict, keys)
2007
2008        a_dict = {"a": torch.ones(1), "b": torch.ones(1) + 1, "c": torch.ones(1) + 2}
2009        self.checkScript(fn, (a_dict, ("a", "c")))
2010
2011    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
2012    def test_ordered_dict(self):
2013        def test_func(fn, inputs):
2014            self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs))
2015
2016        def repeated_key():
2017            return OrderedDict([(1, 2), (2, 3), (1, 4)])
2018
2019        test_func(repeated_key, ())
2020
2021        def no_args():
2022            a = OrderedDict()
2023            a["one"] = torch.tensor(1)
2024            a["two"] = torch.tensor(2)
2025
2026        test_func(no_args, ())
2027
2028        def test_dict_constructor():
2029            a = dict()  # noqa: C408
2030            a["one"] = torch.tensor(1)
2031            return a, dict([(1, 2), (2, 3), (1, 4)])  # noqa: C406
2032
2033        test_func(test_dict_constructor, ())
2034
2035        def test_dict_initializer_list():
2036            a = {"1": torch.tensor(1), "2": torch.tensor(2)}
2037            output_order = []
2038            for key in a:
2039                output_order.append(a[key])
2040            return output_order
2041
2042        test_func(test_dict_initializer_list, ())
2043
2044        def test_dict_error():
2045            a = dict()  # noqa: C408
2046            a[1] = 2
2047            return a
2048
2049        with self.assertRaisesRegexWithHighlight(
2050            Exception, "Arguments for call are not", "a[1] = 2"
2051        ):
2052            torch.jit.script(test_dict_error)
2053
    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
    def test_type_annotation_missing_contained_type(self):
        """
        Test that the use of a Dict type annotation without contained
        key and value types produces an error, both when compiling the
        function's source with CompilationUnit.define and when calling
        torch.jit.script on the function object.
        """

        # NOTE(review): despite its name, this function uses a Python 3
        # annotation, not a `# type:` comment, so it is currently identical
        # to annotated_fn below.
        def fn_with_comment(input: Dict) -> Any:
            return input

        # This function uses Python3 style type annotations.
        def annotated_fn(input: Dict) -> Any:
            return input

        # Source-string frontend.
        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Dict without contained types"
        ):
            cu = torch.jit.CompilationUnit()
            cu.define(dedent(inspect.getsource(fn_with_comment)))

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Dict without contained types"
        ):
            cu = torch.jit.CompilationUnit()
            cu.define(dedent(inspect.getsource(annotated_fn)))

        # Python-object frontend.
        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Dict without contained types"
        ):
            m = torch.jit.script(fn_with_comment)

        with self.assertRaisesRegex(
            RuntimeError, r"Attempted to use Dict without contained types"
        ):
            m = torch.jit.script(annotated_fn)
2090
2091    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
2092    def test_dict_preserves_order(self):
2093        def dict_ordering():
2094            a: Dict[int, int] = {}
2095            for i in range(1000):
2096                a[i] = i + 1
2097            return a
2098
2099        self.checkScript(dict_ordering, ())
2100        di = torch.jit.script(dict_ordering)()
2101        res = list(di.items())
2102        for i in range(1000):
2103            key, value = res[i]
2104            self.assertTrue(key == i and value == i + 1)
2105
    @skipIfTorchDynamo("TorchDynamo fails for this test for unknown reason")
    def test_optional_dict_construct(self):
        """A dict literal annotated as ``Dict[str, Optional[Tensor]]`` can be
        built from plain Tensor values and passed to a method that expects
        that type. forward() ignores its input ``x``."""

        class M(torch.nn.Module):
            def use(self, buffer: Dict[str, Optional[torch.Tensor]]):
                return buffer["prev_key"]

            def forward(self, x):
                prev_key = torch.rand(2, 3)
                next_key = torch.rand(2, 3)
                saved_state: Dict[str, Optional[torch.Tensor]] = {
                    "prev_key": prev_key,
                    "next_key": next_key,
                }

                return self.use(saved_state)

        self.checkModule(M(), (torch.rand(2, 2),))
2123
2124
class TestNamedTuple(JitTestCase):
    """Tests for typing.NamedTuple classes used from TorchScript: construction,
    field access, returning them, annotations (including string forward
    references), slicing/unpacking, tuple lowering and error reporting."""

    def test_namedtuple(self):
        """Construct a NamedTuple inside a scripted function and read its
        fields; (3.0 + 3.0) * 3.0 == 18.0. The tensor argument is unused."""

        class FeatureVector(NamedTuple):
            float_features: float
            sequence_features: List[float]
            time_since_first: float

        @torch.jit.script
        def foo(x) -> float:
            fv = FeatureVector(3.0, [3.0], 3.0)
            rv = fv.float_features
            for val in fv.sequence_features:
                rv += val
            rv *= fv.time_since_first
            return rv

        self.assertEqual(foo(torch.rand(3, 4)), 18.0)

    def test_namedtuple_constant(self):
        """A NamedTuple built from constants and returned from script compares
        equal to the same tuple built in eager Python."""

        class Tup(NamedTuple):
            a: int
            b: int

        @torch.jit.script
        def foo():
            return Tup(1, 2)

        self.assertEqual(foo(), Tup(1, 2))

    def test_return_named_tuple(self):
        """A NamedTuple returned from script exposes its fields by name."""

        class FeatureVector(NamedTuple):
            float_features: float
            sequence_features: List[float]
            time_since_first: float

        @torch.jit.script
        def foo(x):
            fv = FeatureVector(3.0, [3.0], 3.0)
            return fv

        # foo is called twice and only the second result is checked;
        # presumably so that the run after the profiling executor's first
        # (profiling) run is exercised. TODO confirm.
        out = foo(torch.rand(3, 4))
        out = foo(torch.rand(3, 4))
        self.assertEqual(out.float_features, 3.0)
        self.assertEqual(out.sequence_features, [3.0])
        self.assertEqual(out.time_since_first, 3.0)

    def test_namedtuple_as_attr(self):
        """A module attribute of type Dict[int, Config] holding NamedTuple
        values can be scripted. Only scripting is checked; the scripted module
        is never called."""

        class Config(NamedTuple):
            size: int

        class MyMod(nn.Module):
            configs: Dict[int, Config]

            def __init__(self, configs):
                super().__init__()
                self.configs = configs

            def forward(self, x):
                for config in self.configs.values():
                    x += config.size
                return x

        s = torch.jit.script(MyMod({0: Config(size=16)}))

    def test_namedtuple_resolution(self):
        """A NamedTuple reached by attribute access on a ModuleType subclass
        (whose __getattr__ returns TheType for any name) resolves both as a
        return annotation and as a constructor."""

        class TheType(NamedTuple):
            t: int

        class MyModule(types.ModuleType):
            def __init__(self) -> None:
                super().__init__("MyModule")

            def __getattr__(self, attr):
                return TheType

        some_module = MyModule()

        def fn() -> some_module.Type:
            return some_module.Type(1)

        self.checkScript(fn, [])

    def test_namedtuple_slice_unpack(self):
        """Slicing (``tup[:1]`` yields a plain tuple) and tuple unpacking of a
        NamedTuple inside script."""

        class MyCoolNamedTuple(NamedTuple):
            a: int
            b: float
            c: List[int]

        @torch.jit.script
        def foo(a: int, b: float, c: List[int]):
            tup = MyCoolNamedTuple(a, b, c)
            my_a, my_b, my_c = tup
            return tup[:1], my_a, my_c

        self.assertEqual(foo(3, 3.5, [6]), ((3,), 3, [6]))

    def test_namedtuple_lower(self):
        """The graph for a function returning a NamedTuple contains a
        TupleConstruct node, and _jit_pass_lower_all_tuples removes it."""

        class MyCoolNamedTuple(NamedTuple):
            a: int
            b: float
            c: List[int]

        @torch.jit.script
        def foo(a: int):
            tup = MyCoolNamedTuple(a, 3.14, [9])
            return tup

        FileCheck().check("TupleConstruct").run(foo.graph)
        torch._C._jit_pass_lower_all_tuples(foo.graph)
        FileCheck().check_not("TupleConstruct").run(foo.graph)

    def test_namedtuple_type_annotation(self):
        """A NamedTuple used as both parameter and return annotation
        round-trips an eager instance through script."""
        global MyCoolNamedTuple  # see [local resolution in python]

        class MyCoolNamedTuple(NamedTuple):
            a: int
            b: float
            c: List[int]

        @torch.jit.script
        def foo(x: MyCoolNamedTuple) -> MyCoolNamedTuple:
            return x

        mnt = MyCoolNamedTuple(42, 420.0, [666])
        self.assertEqual(foo(mnt), mnt)

    def test_namedtuple_wrong_types(self):
        """Constructing a NamedTuple with wrongly typed arguments fails while
        scripting, with an error about field 'a' receiving a 'str'."""

        class MyCoolNamedTuple(NamedTuple):
            a: int
            b: float
            c: List[int]

        with self.assertRaisesRegex(
            RuntimeError,
            "Expected a value of type 'int' for argument 'a'"
            " but instead found type 'str'",
        ):

            @torch.jit.script
            def foo():
                tup = MyCoolNamedTuple("foo", "bar", "baz")
                return tup

    def test_namedtuple_kwarg_construct(self):
        """Keyword construction (given out of declaration order) assigns each
        field by name."""

        class MyCoolNamedTuple(NamedTuple):
            a: int
            b: float
            c: List[int]

        @torch.jit.script
        def foo():
            tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9)
            return tup

        tup = foo()
        self.assertEqual(tup.a, 9)
        self.assertEqual(tup.b, 3.5)
        self.assertEqual(tup.c, [1, 2, 3])

    # NOTE(review): if re-enabled, this test writes "foo.zip" into the
    # current working directory and never removes it; consider a temporary
    # directory.
    @unittest.skipIf(True, "broken while these tests were not in CI")
    def test_namedtuple_serialization(self):
        """Save/load round trip of a ScriptModule returning a NamedTuple; the
        loaded module's fields must match the original's. Always skipped."""

        class MyCoolNamedTuple(NamedTuple):
            a: int
            b: float
            c: List[int]

        class MyMod(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self):
                return MyCoolNamedTuple(3, 3.5, [3, 4, 5])

        mm = MyMod()
        mm.save("foo.zip")
        torch.testing._internal.jit_utils.clear_class_registry()
        loaded = torch.jit.load("foo.zip")

        out = mm()
        out_loaded = loaded()

        for name in ["a", "b", "c"]:
            self.assertEqual(getattr(out_loaded, name), getattr(out, name))

    def test_namedtuple_inside_forwardref(self):
        """Same as test_namedtuple, but the field annotations are strings
        (forward references)."""

        class FeatureVector(NamedTuple):
            float_features: "float"
            sequence_features: "List[float]"
            time_since_first: "float"

        @torch.jit.script
        def foo(x) -> float:
            fv = FeatureVector(3.0, [3.0], 3.0)
            rv = fv.float_features
            for val in fv.sequence_features:
                rv += val
            rv *= fv.time_since_first
            return rv

        self.assertEqual(foo(torch.rand(3, 4)), 18.0)

    def test_namedtuple_input_forwardref(self):
        """A NamedTuple with string field annotations passed as an input to a
        scripted function gives the same result as the eager function."""

        class MyNamedTuple(NamedTuple):
            a: "int"
            b: "float"
            c: "torch.Tensor"

        # Presumably needed so script can resolve MyNamedTuple from fn's
        # annotation, since the class is local to this method. TODO confirm.
        make_global(MyNamedTuple)

        nt = MyNamedTuple(4, 2.5, torch.rand((2, 2)))

        def fn(obj: MyNamedTuple):
            return ((obj.c + obj.b) ** obj.a).sin()

        expected = fn(nt)
        fn_s = torch.jit.script(fn)
        actual = fn_s(nt)
        self.assertEqual(expected, actual)

    # see #95858
    @unittest.expectedFailure
    def test_namedtuple_resolution_forwardref(self):
        """Forward-reference variant of test_namedtuple_resolution; currently
        marked as an expected failure."""

        class TheType(NamedTuple):
            t: "int"

        class MyModule(types.ModuleType):
            def __init__(self) -> None:
                super().__init__("MyModule")

            def __getattr__(self, attr):
                return TheType

        some_module = MyModule()

        def fn() -> some_module.Type:
            return some_module.Type(1)

        self.checkScript(fn, [])
2361
2362
2363class TestScriptDict(JitTestCase):
2364    """
2365    This class contains a suite of tests for torch.jit.script, a
2366    function that returns a dictionary-like object that has reference
2367    semantics across the Python/TorchScript boundary. That is,
2368    it can be passed to a TorchScript function that mutates it
2369    and those modifications are visible in the scope of the Python
2370    caller of said TorchScript function.
2371
2372    The vast majority of tests are for making sure that objects returned
2373    by torch.jit.script behave like dictionaries do so that they are fungible
2374    in almost all cirumstances with regular dictionaries.
2375    """
2376
2377    def _script_dict_add(self, d: torch._C.ScriptDict, k: int, v: int):
2378        """
2379        This is a helper function that inserts the pair (k, v) into the
2380        dictionary d in TorchScript. It is used for testing reference
2381        semantics.
2382        """
2383
2384        @torch.jit.script
2385        def dict_add(d: Dict[int, int], k: int, v: int):
2386            d[k] = v
2387
2388        dict_add(d, k, v)
2389
2390    def _compare_eager_and_script(self, fn, input_dict, script_input_dict=None):
2391        """
2392        This is a helper function that facilitates comparing behaviour between
2393        Python dictionaries and "scripted" dictionaries.
2394
2395        Args:
2396            fn: The function to test and compare the behaviour of.
2397            input_dict: The input dictionary to use for the test (passed to fn).
2398            script_input_dict: The scripted input dictionary to use for the tests.
2399                                If None, input_dict is scripted with torch.jit.script
2400                                and used instead.
2401        """
2402        # Create ScriptDict version of input_dict if needed.
2403        script_input_dict = script_input_dict or torch.jit.script(input_dict)
2404
2405        # Run fn with both input_dict and scripted_dict.
2406        eager_raised, script_raised = False, False
2407
2408        try:
2409            eager_out = fn(input_dict)
2410        except Exception as e:
2411            eager_exception = e
2412            eager_raised = True
2413
2414        try:
2415            script_out = fn(script_input_dict)
2416        except Exception as e:
2417            script_exception = e
2418            script_raised = True
2419
2420        # Check that both calls raised or none of them raised.
2421        self.assertEqual(eager_raised, script_raised)
2422
2423        if eager_raised:
2424            # If fn raised an exception, it should be the same between
2425            # regular and scripted dictionaries.
2426            self.assertEqual(type(eager_exception), type(script_exception))
2427        else:
2428            # Otherwise, make sure the outputs match and the dictionaries
2429            # match (the latter may not be the same as the output).
2430            self.assertEqual(eager_out, script_out)
2431            self.assertEqual(input_dict, script_input_dict)
2432
2433    def test_repr(self):
2434        """
2435        Test the __repr__ method.
2436        """
2437        self._compare_eager_and_script(lambda d: repr(d), {1: 2})
2438
2439    def test_bool(self):
2440        """
2441        Test the __bool__ method. This should return True
2442        if the dictionary is non-empty and False otherwise.
2443        """
2444        self._compare_eager_and_script(lambda d: bool(d), {1: 2})
2445        self._compare_eager_and_script(lambda d: bool(d), {})
2446
2447    def test_iter(self):
2448        """
2449        Test iteration over a dictionary's keys.
2450        """
2451
2452        def sum_keys(input_dict):
2453            s = 0
2454            for k in input_dict:
2455                s += k
2456
2457            return s
2458
2459        self._compare_eager_and_script(sum_keys, {1: 2, 3: 4})
2460
2461    def test_items(self):
2462        """
2463        Test .items().
2464        """
2465
2466        def sum_pair_product(input_dict):
2467            s = 0
2468            for k, v in input_dict.items():
2469                s += k * v
2470
2471            return s
2472
2473        self._compare_eager_and_script(sum_pair_product, {1: 2, 3: 4})
2474
2475    def test_getitem(self):
2476        """
2477        Test accessing dictionary values using the [] operator.
2478        """
2479        data = {1: 2, 3: 4}
2480        self._compare_eager_and_script(lambda d: d[1], data)
2481        self._compare_eager_and_script(lambda d: d[4], data)
2482        self._compare_eager_and_script(lambda d: d[2], data)
2483        self._compare_eager_and_script(lambda d: d["key"], data)
2484
2485    def test_setitem(self):
2486        """
2487        Test setting dictionary values using the [] operator.
2488        """
2489        data = {1: 2, 3: 4}
2490
2491        def fn(input_dict):
2492            input_dict[1] = 10
2493            input_dict[3] = 11
2494
2495        self._compare_eager_and_script(fn, data)
2496
2497        # Check that using improperly typed keys and values
2498        # throws TypeError.
2499        # _compare_eager_and_script cannot be used here since
2500        # the following uses of __setitem__ are valid in
2501        # Python.
2502        script_data = torch.jit.script(data)
2503
2504        with self.assertRaises(TypeError):
2505            script_data["str"] = 3
2506
2507        with self.assertRaises(TypeError):
2508            script_data[3] = "str"
2509
2510    def test_contains(self):
2511        """
2512        Test membership checks (x in y, x not in y).
2513        """
2514        data = {1: 2, 3: 4}
2515
2516        def fn(input_dict):
2517            return (
2518                1 in input_dict,
2519                2 not in input_dict,
2520                3 in input_dict,
2521                4 not in input_dict,
2522            )
2523
2524        self._compare_eager_and_script(fn, data)
2525
2526        # Check that using an improperly typed key
2527        # throws KeyError.
2528        script_data = torch.jit.script(data)
2529
2530        with self.assertRaises(KeyError):
2531            a = "str" in script_data
2532
2533    def test_delitem(self):
2534        """
2535        Test deletion.
2536        """
2537        data = {1: 2, 3: 4}
2538
2539        def del_fn(input_dict):
2540            del input_dict[1]
2541
2542        def del_fn_raises(input_dict):
2543            del input_dict[10]
2544
2545        self._compare_eager_and_script(del_fn, data)
2546        self._compare_eager_and_script(del_fn_raises, data)
2547
2548        # Check that using an improperly typed key
2549        # throws TypeError.
2550        script_data = torch.jit.script(data)
2551
2552        with self.assertRaises(TypeError):
2553            del script_data["str"]
2554
2555    def test_len(self):
2556        """
2557        Test len() builtin function.
2558        """
2559        self._compare_eager_and_script(lambda d: len(d), {1: 2})
2560        self._compare_eager_and_script(lambda d: len(d), {})
2561
2562    @unittest.skip(
2563        "Cannot pass until all dicts returned from TorchScript are ScriptDicts"
2564    )
2565    def test_nested(self):
2566        """
2567        Test that reference semantics are honoured when the ScriptDict that is
2568        mutated using TorchScript is inside another.
2569        """
2570        nested = torch.jit.script(
2571            {1: {1: 2}, 2: {3: 4}}, type_hint=Dict[int, Dict[int, int]]
2572        )
2573
2574        one = nested[1]
2575        two = nested[2]
2576
2577        self._script_dict_add(one, 9, 10)
2578        self._script_dict_add(two, 11, 12)
2579
2580        # The mutation should be visible in the original dictionary, nested.
2581        self.assertEqual(len(one), 2)
2582        self.assertEqual(len(two), 2)
2583        self.assertEqual(len(nested[1]), 2)
2584        self.assertEqual(len(nested[2]), 2)
2585
2586    def test_reference_semantics(self):
2587        """
2588        Test that reference semantics are honoured; that modifications made
2589        to a ScriptDict in TorchScript are visible in Python.
2590        """
2591        data = torch.jit.script({1: 2})
2592        self._script_dict_add(data, 3, 4)
2593
2594        # The mutation should be visible in the original dictionary.
2595        self.assertEqual(len(data), 2)
2596        self.assertTrue(3 in data)
2597        self.assertEqual(data[3], 4)
2598
2599
2600class TestScriptList(JitTestCase):
2601    """
    This class contains a suite of tests for torch._C.ScriptList, the
    list-like type that torch.jit.script returns when given a Python list.
    A ScriptList has reference semantics across the Python/TorchScript
    boundary. That is, it can be passed to a TorchScript function that
    mutates it and those modifications are visible in the scope of the
    Python caller of said TorchScript function.

    The vast majority of tests are for making sure that instances of
    torch._C.ScriptList behave like lists do so that they are fungible
    in almost all circumstances with regular lists.
2612    """
2613
2614    def _script_list_add(self, l: torch._C.ScriptList, e: int):
2615        """
2616        This is a helper function that inserts the element e into the
2617        list l in TorchScript. It is used for testing reference
2618        semantics.
2619        """
2620
2621        @torch.jit.script
2622        def list_add(l: List[int], e: int):
2623            l.append(e)
2624
2625        list_add(l, e)
2626
2627    def _compare_eager_and_script(self, fn, input_list, script_input_list=None):
2628        """
2629        This is a helper function that facilitates comparing behaviour between
2630        Python lists and "scripted" lists.
2631        Args:
2632            fn: The function to test and compare the behaviour of.
2633            input_list: The input list to use for the test (passed to fn).
2634            script_input_list: The scripted input list to use for the tests.
2635                                If None, input_list is scripted with torch.jit.script
2636                                and used instead.
2637        """
2638        # Create ScriptDict version of input_list if needed.
2639        script_input_list = script_input_list or torch.jit.script(input_list)
2640
2641        # Run fn with both input_list and scripted_dict.
2642        eager_raised, script_raised = False, False
2643
2644        try:
2645            eager_out = fn(input_list)
2646        except Exception as e:
2647            eager_exception = e
2648            eager_raised = True
2649
2650        try:
2651            script_out = fn(script_input_list)
2652        except Exception as e:
2653            script_exception = e
2654            script_raised = True
2655
2656        # Check that both calls raised or none of them raised.
2657        self.assertEqual(eager_raised, script_raised)
2658
2659        if eager_raised:
2660            # If fn raised an exception, it should be the same between
2661            # regular and scripted lists.
2662            self.assertEqual(type(eager_exception), type(script_exception))
2663        else:
2664            # Otherwise, make sure the outputs match and the lists
2665            # match (the latter may not be the same as the output).
2666            self.assertEqual(eager_out, script_out)
2667            self.assertEqual(input_list, script_input_list)
2668
2669    def test_repr(self):
2670        """
2671        Test the __repr__ method.
2672        """
2673        self._compare_eager_and_script(lambda l: repr(l), [1])
2674
2675    def test_bool(self):
2676        """
2677        Test the __bool__ method. This should return True
2678        if the list is non-empty and False otherwise.
2679        """
2680        self._compare_eager_and_script(lambda l: bool(l), [1])
2681        self._compare_eager_and_script(lambda l: bool(l), [])
2682
2683    def test_iter(self):
2684        """
2685        Test iteration over a list's elements.
2686        """
2687
2688        def sum_elements(input_list):
2689            s = 0
2690            for k in input_list:
2691                s += k
2692
2693            return s
2694
2695        self._compare_eager_and_script(sum_elements, [1, 2, 3, 4])
2696
2697    def test_getitem(self):
2698        """
2699        Test accessing list elements using the [] operator.
2700        """
2701        data = [1, 2, 3, 4]
2702
2703        # Test regular indexing.
2704        self._compare_eager_and_script(lambda l: l[1], data)
2705        self._compare_eager_and_script(lambda l: l[3], data)
2706        self._compare_eager_and_script(lambda l: l[-1], data)
2707
2708        # Test slicing.
2709        self._compare_eager_and_script(lambda l: l[1:3], data)
2710        self._compare_eager_and_script(lambda l: l[:], data)
2711        self._compare_eager_and_script(lambda l: l[1:], data)
2712        self._compare_eager_and_script(lambda l: l[:2], data)
2713        self._compare_eager_and_script(lambda l: l[-1], data)
2714        self._compare_eager_and_script(lambda l: l[-1::-1], data)
2715
2716        # Test errors.
2717        self._compare_eager_and_script(lambda l: l[5], data)
2718        self._compare_eager_and_script(lambda l: l[-7], data)
2719        self._compare_eager_and_script(lambda l: l["key"], data)
2720
2721    def test_setitem(self):
2722        """
2723        Test setting list elements using the [] operator.
2724        """
2725        data = [1, 2, 3, 4]
2726
2727        # Test regular assignment.
2728        def setitem(input_list):
2729            input_list[1] = 10
2730            input_list[3] = 11
2731            input_list[-1] = 12
2732
2733        self._compare_eager_and_script(setitem, data.copy())
2734
2735        # Test slice assignment.
2736        # TODO: Something like input_list[:1] = [1, 2, 3, 4, 5]
2737        # is allowed in Python, but pybind11/stl_bind.h does not
2738        # allow it. Should we?
2739        def setitem_slice(input_list):
2740            input_list[:4:2] = [10, 11]
2741            input_list[-2:] = [15, 16]
2742
2743        self._compare_eager_and_script(setitem_slice, data)
2744
2745        # Test errors.
2746        def out_of_range(input_list):
2747            input_list[11] = 3
2748
2749        def out_of_range_negative(input_list):
2750            input_list[-11] = 3
2751
2752        def wrong_index_type(input_list):
2753            input_list["str"] = 3
2754
2755        self._compare_eager_and_script(out_of_range, data)
2756        self._compare_eager_and_script(out_of_range_negative, data)
2757        self._compare_eager_and_script(wrong_index_type, data)
2758
2759        # Check that using value of an incorrect type throws TypeError.
2760        # _compare_eager_and_script cannot be used here since
2761        # the following use of __setitem__ is valid in
2762        # Python.
2763        script_data = torch.jit.script(data)
2764
2765        with self.assertRaises(TypeError):
2766            script_data[0] = "str"
2767
2768    def test_contains(self):
2769        """
2770        Test membership checks (x in y, x not in y).
2771        """
2772        data = [1, 2, 3, 4]
2773
2774        def fn(input_list):
2775            return (
2776                1 in input_list,
2777                2 not in input_list,
2778                3 in input_list,
2779                4 not in input_list,
2780            )
2781
2782        self._compare_eager_and_script(fn, data)
2783
2784        # Check that using a value of an incorrect type throws a TypeError.
2785        script_data = torch.jit.script(data)
2786
2787        with self.assertRaises(TypeError):
2788            a = "str" in script_data
2789
2790    def test_delitem(self):
2791        """
2792        Test deletion.
2793        """
2794        data = [1, 2, 3, 4]
2795
2796        def del_fn(input_list):
2797            del input_list[1]
2798
2799        def del_fn_out_of_range(input_list):
2800            del input_list[10]
2801
2802        def del_fn_wrong_type(input_list):
2803            del input_list["str"]
2804
2805        self._compare_eager_and_script(del_fn, data.copy())
2806        self._compare_eager_and_script(del_fn_out_of_range, data)
2807        self._compare_eager_and_script(del_fn_wrong_type, data)
2808
2809    def test_len(self):
2810        """
2811        Test len() builtin function.
2812        """
2813        self._compare_eager_and_script(lambda l: len(l), [1, 2, 3, 4])
2814        self._compare_eager_and_script(lambda l: len(l), [])
2815
2816    def test_count(self):
2817        """
2818        Test count method.
2819        """
2820        self._compare_eager_and_script(lambda l: l.count(3), [1, 2, 3, 3])
2821
2822        # Check that using a value of an incorrect type throws TypeError.
2823        script_data = torch.jit.script([1])
2824
2825        with self.assertRaises(TypeError):
2826            script_data.count("str")
2827
2828    def test_remove(self):
2829        """
2830        Test remove method.
2831        """
2832        self._compare_eager_and_script(lambda l: l.remove(1), [1, 2, 3])
2833        self._compare_eager_and_script(lambda l: l.remove(10), [1, 2, 3])
2834
2835        # Check that using a value of an incorrect type throws TypeError.
2836        script_data = torch.jit.script([1])
2837
2838        with self.assertRaises(TypeError):
2839            script_data.remove("str")
2840
2841    def test_append(self):
2842        """
2843        Test append method.
2844        """
2845        self._compare_eager_and_script(lambda l: l.append(1), [4, 3, 2])
2846
2847        # Check that using a value of an incorrect type throws TypeError.
2848        script_data = torch.jit.script([1])
2849
2850        with self.assertRaises(TypeError):
2851            script_data.append("str")
2852
2853    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
2854    def test_clear(self):
2855        """
2856        Test clear.
2857        """
2858        self._compare_eager_and_script(lambda l: l.clear(), [4, 3, 2])
2859
2860    def test_extend(self):
2861        """
2862        Test extend.
2863        """
2864
2865        class Iterable:
2866            def __init__(self, limit: int):
2867                self.limit = limit
2868                self.value = 0
2869
2870            def __iter__(self):
2871                return self
2872
2873            def __next__(self):
2874                if self.value == limit:  # noqa: F821
2875                    raise StopIteration
2876
2877                ret = self.value
2878                self.value += 1
2879                return ret
2880
2881        data = [1, 2, 3]
2882
2883        def extend_list(input_list):
2884            input_list.extend([4, 5, 6])
2885
2886        def extend_dict(input_list):
2887            input_list.extend({4: 10, 5: 11, 6: 12})
2888
2889        def extend_iterable(input_list):
2890            input_list.extend(Iterable(3))
2891
2892        self._compare_eager_and_script(extend_list, data.copy())
2893        self._compare_eager_and_script(extend_dict, data.copy())
2894        self._compare_eager_and_script(extend_iterable, data)
2895
2896        # Check that using a value of an incorrect type throws TypeError.
2897        script_data = torch.jit.script([1])
2898
2899        with self.assertRaises(TypeError):
2900            script_data.extend(["a"])
2901
2902        with self.assertRaises(TypeError):
2903            script_data.extend({"a": 1})
2904
2905    def test_insert(self):
2906        """
2907        Test insert.
2908        """
2909        data = [1, 2, 4]
2910
2911        self._compare_eager_and_script(lambda l: l.insert(3, 3), data.copy())
2912        self._compare_eager_and_script(lambda l: l.insert(0, 3), data.copy())
2913        self._compare_eager_and_script(lambda l: l.insert(-2, 3), data)
2914
2915        # Check that using a value of an incorrect type throws TypeError.
2916        script_data = torch.jit.script([1])
2917
2918        with self.assertRaises(TypeError):
2919            script_data.insert((0, "str"))
2920
2921    def test_pop(self):
2922        """
2923        Test pop.
2924        """
2925        data = [1, 2, 3, 4, 5]
2926
2927        # Test normal cases.
2928        self._compare_eager_and_script(lambda l: l.pop(), data.copy())
2929        self._compare_eager_and_script(lambda l: l.pop(2), data.copy())
2930        self._compare_eager_and_script(lambda l: l.pop(-3), data.copy())
2931
2932        # Test error cases.
2933        self._compare_eager_and_script(lambda l: l.pop(10), data)
2934
2935    @unittest.skip(
2936        "Cannot pass until all list returned from TorchScript are ScriptLists"
2937    )
2938    def test_nested(self):
2939        """
2940        Test that reference semantics are honoured when the ScriptList that is
2941        mutated using TorchScript is inside another.
2942        """
2943        nested = torch.jit.script([[1], [2]], List[List[int]])
2944
2945        one = nested[0]
2946        two = nested[1]
2947
2948        self._script_list_add(one, 3)
2949        self._script_list_add(two, 4)
2950
2951        # The mutation should be visible in the original list, nested.
2952        self.assertEqual(len(one), 2)
2953        self.assertEqual(len(two), 2)
2954        self.assertEqual(one[len(one) - 1], 3)
2955        self.assertEqual(two[len(one) - 1], 4)
2956        self.assertEqual(len(nested[0]), 2)
2957        self.assertEqual(len(nested[1]), 2)
2958
2959    def test_reference_semantics(self):
2960        """
2961        Test that reference semantics are honoured; that modifications made
2962        to a ScriptList in TorchScript are visible in Python.
2963        """
2964        l = torch.jit.script([1, 2])
2965        self._script_list_add(l, 3)
2966
2967        self.assertEqual(len(l), 3)
2968        self.assertTrue(3 in l)
2969        self.assertEqual(l[2], 3)
2970
2971    def test_defaultdict(self):
2972        def get_dict():
2973            test_dict = defaultdict(list)
2974            return test_dict
2975
2976        class Test(torch.nn.Module):
2977            segments_groupby_col: Dict[str, List[str]]
2978
2979            def __init__(self) -> None:
2980                super().__init__()
2981                self.segments_groupby_col = get_dict()
2982                self.col1 = "a"
2983                self.col2 = "b"
2984
2985            def forward(self):
2986                if self.col1 in self.segments_groupby_col.keys():
2987                    return 1
2988                else:
2989                    return 2
2990
2991        test = Test()
2992        test_script = torch.jit.script(test)
2993        test_script.segments_groupby_col
2994
2995        # Smoketest for flakiness. Takes around 2s.
2996        for i in range(300):
2997            test = Test()
2998            test_script = torch.jit.script(test)
2999