xref: /aosp_15_r20/external/pytorch/test/jit/test_peephole.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import unittest
4from typing import Callable, List
5
6import torch
7from torch import nn
8from torch.testing import FileCheck
9from torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA
10
11
if __name__ == "__main__":
    # This module is meant to be collected through test/test_jit.py; refuse to
    # run standalone and tell the caller how to invoke it instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
18
19
20class TestPeephole(JitTestCase):
21    def test_peephole_with_writes(self):
22        def test_write(x):
23            s = 0
24            s += x
25            s += x
26            return s
27
28        self.checkScript(test_write, (torch.ones(4, 4),))
29
30    def test_peephole_with_non_output_writes(self):
31        @torch.jit.ignore
32        def nomnom(x):
33            pass
34
35        def test_write(x):
36            t = torch.ones_like(x)
37            z = x.clone()
38            y = z + 0
39            z.add_(t)
40            # this makes sure z isn't blasted out of existence
41            # because it isn't returned or used in a side-effectful
42            # way
43            nomnom(z)
44            return y + y
45
46        a = torch.ones(4, 4)
47        j = self.checkScript(test_write, (a,))
48
49    def test_peephole_no_output_aliasing(self):
50        def test_peephole(x):
51            y = x + 0
52            return x, y
53
54        a = torch.ones(4, 4)
55        j = self.checkScript(test_peephole, (a,))
56        r1, r2 = j(a)
57        self.assertNotEqual(r1.data_ptr(), r2.data_ptr())
58
59    def test_peephole(self):
60        a = torch.tensor([0.4])
61        b = torch.tensor([0.7])
62        c = torch.tensor([0], dtype=torch.int32)
63
64        def f(x, y):
65            return x.type_as(y)
66
67        tf = torch.jit.trace(f, (a, b))
68        FileCheck().check("type_as").run(str(tf.graph))
69        self.run_pass("peephole", tf.graph)
70        FileCheck().check_not("type_as").run(str(tf.graph))
71        tf2 = torch.jit.trace(f, (a, c))
72        s = str(tf2.graph)
73        self.run_pass("peephole", tf2.graph)
74        self.assertEqual(s, str(s))
75
76    def test_peephole_dynamic(self):
77        def f(x, y):
78            return x.type_as(y)
79
80        fn = torch.jit.script(f)
81        s = str(fn.graph)
82        torch._C._jit_pass_peephole(fn.graph)
83        self.assertEqual(s, str(fn.graph))
84
85    def test_peephole_list_ops(self):
86        @torch.jit.script
87        def foo(x, y, z):
88            return len([x, y, z])
89
90        self.run_pass("peephole", foo.graph)
91        FileCheck().check("value=3").check_next("return").run(foo.graph)
92
93        @torch.jit.script
94        def foo(x, y, z):
95            li = [x, y, z]
96            for i in range(len(x)):
97                li.append(x)
98            return len([x, y, z])
99
100        self.run_pass("peephole", foo.graph)
101        FileCheck().check_not("aten::len").run(foo.graph)
102
103        @torch.jit.script
104        def foo(x, y, z):
105            li = [x, y, z]
106            return li[1], li[-2]
107
108        FileCheck().check("aten::__getitem__").run(foo.graph)
109        self.run_pass("peephole", foo.graph)
110        FileCheck().check_not("aten::__getitem__").run(foo.graph)
111
112        @torch.jit.script
113        def foo(x, y, z):
114            li = [x, y, z]
115            return li[-7]
116
117        self.run_pass("peephole", foo.graph)
118        FileCheck().check("aten::__getitem__").run(foo.graph)
119
120        @torch.jit.script
121        def foo(x, y, z):
122            li = [x, y, z]
123            for i in range(len(x)):
124                li.append(x)
125            return li[-2]
126
127        self.run_pass("peephole", foo.graph)
128        FileCheck().check("aten::__getitem__").run(foo.graph)
129
130    @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
131    def test_peephole_cuda(self):
132        a = torch.tensor([0.4], device="cpu")
133        b = torch.tensor([0.7], device="cuda")
134        c = torch.tensor([0.7], device="cuda")
135
136        def f(x, y):
137            return x.type_as(y)
138
139        trace = torch.jit.trace(f, (a, c))
140        s = str(trace.graph)
141        self.run_pass("peephole", trace.graph)
142        self.assertEqual(s, str(trace.graph))
143        trace = torch.jit.trace(f, (b, c))
144        self.run_pass("peephole", trace.graph)
145        self.run_pass("dce", trace.graph)
146        FileCheck().check_not("type_as").run(str(trace.graph))
147
148    @_inline_everything
149    def test_peephole_type_refinements(self):
150        def refine(x):
151            # type: (Optional[Tensor]) -> Tensor
152            return x if x is not None else torch.tensor(3)
153
154        @torch.jit.script
155        def test():
156            return refine(torch.tensor(4))
157
158        FileCheck().check("prim::unchecked_cast").run(test.graph)
159        self.run_pass("peephole", test.graph)
160        FileCheck().check_not("prim::unchecked_cast").run(test.graph)
161
162        # refinement not optimzied out
163        def is_int_tensor(x):
164            scalar = x.item()
165            if isinstance(scalar, int):
166                return scalar + 3
167            else:
168                return 8
169
170        self.checkScript(is_int_tensor, (torch.tensor(2),))
171        self.checkScript(is_int_tensor, (torch.tensor(2.5),))
172        graph = torch.jit.script(is_int_tensor).graph
173        self.run_pass("peephole", graph)
174        FileCheck().check("prim::unchecked_cast").run(graph)
175
176    def test_short_circuit_optimization(self):
177        @torch.jit.script
178        def const_expressions(x):
179            # type: (int) -> Tuple[bool, bool]
180            return x == 1 and False, x == 1 or True
181
182        self.run_pass("constant_propagation", const_expressions.graph)
183        FileCheck().check_not("prim::If").check_not("aten::eq").run(
184            const_expressions.graph
185        )
186        self.assertEqual(const_expressions(1), (False, True))
187
188        @torch.jit.script
189        def redundant_expressions(x):
190            # type: (int) -> Tuple[bool, bool]
191            return x == 1 and True, x == 1 or False
192
193        self.run_pass("peephole", redundant_expressions.graph)
194        self.assertEqual(redundant_expressions(1), (True, True))
195        self.assertEqual(redundant_expressions(0), (False, False))
196        # and True / or False are removed from graph
197        FileCheck().check("aten::eq").check_not("prim::If").run(
198            redundant_expressions.graph
199        )
200
201    def test_conv_dim_folding(self):
202        modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
203        for mod in modules:
204
205            class ConvDim(torch.nn.Module):
206                def __init__(self) -> None:
207                    super().__init__()
208                    self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False)
209
210                def forward(self, x):
211                    x = self.conv(x)
212                    return x.dim()
213
214            conv_dim = torch.jit.script(ConvDim())
215            self.run_pass("inline", conv_dim.graph)
216            self.run_pass("peephole", conv_dim.graph)
217            FileCheck().check_not("conv").check_not("dim").run(conv_dim.graph)
218
219            class ConvDimMutate(torch.nn.Module):
220                def __init__(self) -> None:
221                    super().__init__()
222                    self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False)
223
224                def forward(self, x):
225                    x = self.conv(x)
226                    x.resize_([4, 4])
227                    return x.dim()
228
229            conv_dim = torch.jit.script(ConvDimMutate())
230            self.run_pass("inline", conv_dim.graph)
231            self.run_pass("peephole", conv_dim.graph)
232            FileCheck().check("conv").check("dim").run(conv_dim.graph)
233
234    def test_normalized_rsub(self):
235        a = torch.tensor([1, 2, 3])
236        b = torch.tensor([4, 5, 6])
237
238        def convertible_rsub(x, y):
239            return (x - y), torch.rsub(y, x)
240
241        self.checkScript(convertible_rsub, (a, b))
242        op_graph = torch.jit.script(convertible_rsub).graph
243        FileCheck().check_count("aten::sub", 2, exactly=True).run(op_graph)
244        FileCheck().check_count("aten::rsub", 0, exactly=True).run(op_graph)
245
246    def test_normalized_is_op(self):
247        def convertible_is_op(x: bool, y: bool):
248            return x is True, False is x, x is y
249
250        self.checkScript(convertible_is_op, (True, False))
251
252        op_graph = torch.jit.script(convertible_is_op).graph
253        FileCheck().check_count("aten::eq", 3, exactly=True).run(op_graph)
254        FileCheck().check_count("aten::__is__", 0, exactly=True).run(op_graph)
255
256    def test_normalized_isnot_op(self):
257        def convertible_isnot_op(x: bool, y: bool):
258            return x is not True, False is not x, x is not y
259
260        self.checkScript(convertible_isnot_op, (True, False))
261
262        op_graph = torch.jit.script(convertible_isnot_op).graph
263        FileCheck().check_count("aten::ne", 3, exactly=True).run(op_graph)
264        FileCheck().check_count("aten::__isnot__", 0, exactly=True).run(op_graph)
265
266    def test_peephole_list_len(self):
267        def run_peephole_and_check_const_value(graph, const_string):
268            torch._C._jit_pass_peephole_list_idioms(graph, refine_list_len=True)
269            self.run_pass("constant_propagation", graph)
270            FileCheck().check(const_string).check_next("return").run(graph)
271
272        def gen_li(inp_len: int):
273            return [0 for i in range(inp_len)]
274
275        @torch.jit.script
276        def foo(x: List[int], y: List[int]):
277            if len(x) != 4 or len(y) != 5:
278                raise Exception("")  # noqa: TRY002
279
280            return len(x) + len(y)
281
282        run_peephole_and_check_const_value(foo.graph, "value=9")
283        self.assertEqual(foo(gen_li(4), gen_li(5)), 9)
284        with self.assertRaises(Exception):
285            foo(2, 4)
286
287        @torch.jit.script
288        def foo(x: List[int], y: List[int]):
289            if len(x) == 4 and len(y) == 5:
290                pass
291            else:
292                raise Exception("hi")  # noqa: TRY002
293
294            return len(x) + len(y)
295
296        run_peephole_and_check_const_value(foo.graph, "value=9")
297        self.assertEqual(foo(gen_li(4), gen_li(5)), 9)
298        with self.assertRaises(Exception):
299            foo(2, 4)
300
301        @torch.jit.script
302        def foo(x: List[int], y: List[int], z: List[int]):
303            if len(x) != 4:
304                raise Exception("..")  # noqa: TRY002
305            else:
306                if len(y) != 8:
307                    raise Exception("...")  # noqa: TRY002
308                else:
309                    if len(z) == 3:
310                        pass
311                    else:
312                        raise Exception("...")  # noqa: TRY002
313
314            return len(x) + len(y) * len(z)
315
316        run_peephole_and_check_const_value(foo.graph, "value=28")
317        self.assertEqual(foo(gen_li(4), gen_li(8), gen_li(3)), 28)
318        with self.assertRaises(Exception):
319            foo(1, 2, 3)
320
321        # refinement should persist in second len(x) call
322
323        @torch.jit.script
324        def foo(x: List[int], cond: bool):
325            if len(x) == 4:
326                if cond:
327                    return len(x)
328                return 4
329
330            return 4
331
332        run_peephole_and_check_const_value(foo.graph, "value=4")
333
334        def test_const_tuple_output(graph, const_inputs):
335            tup = graph.findNode("prim::TupleConstruct")
336            for i, elem in enumerate(tup.inputs()):
337                if i in const_inputs:
338                    self.assertIsNotNone(elem.toIValue())
339                else:
340                    self.assertIsNone(elem.toIValue())
341
342        # testing combinations of x1 : {True, False} x
343        # {then/else branch} x assert {True/False}
344
345        @torch.jit.script
346        def foo(x: List[int], b: List[int]):
347            if len(x) == 5:
348                x1 = True
349            else:
350                x1 = len(b) != 4
351            assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
352            return len(x), len(b)
353
354        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
355        torch._C._jit_pass_constant_propagation(foo.graph)
356        # we can only infer len(b) == 4 here
357        test_const_tuple_output(foo.graph, [1])
358
359        @torch.jit.script
360        def foo(x: List[int], b: List[int]):
361            if len(x) == 5:
362                x1 = False
363            else:
364                x1 = len(b) != 4
365            assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
366            return len(x), len(b)
367
368        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
369        torch._C._jit_pass_constant_propagation(foo.graph)
370        # cant infer anything
371        test_const_tuple_output(foo.graph, [])
372
373        @torch.jit.script
374        def foo(x: List[int], b: List[int]):
375            if len(x) == 5:
376                x1 = True
377            else:
378                x1 = len(b) == 4
379            assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
380            return len(x), len(b)
381
382        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
383        torch._C._jit_pass_constant_propagation(foo.graph)
384        # we cant infer anything, only len(b) != 4
385        test_const_tuple_output(foo.graph, [])
386
387        @torch.jit.script
388        def foo(x: List[int], b: List[int]):
389            if len(x) == 5:
390                x1 = True
391            else:
392                x1 = len(b) != 4
393            assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
394            return len(x), len(b)
395
396        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
397        torch._C._jit_pass_constant_propagation(foo.graph)
398        # can infer len(b) == 4
399        test_const_tuple_output(foo.graph, [1])
400
401        # swap branches
402        @torch.jit.script
403        def foo(x: List[int], b: List[int]):
404            if len(x) != 5:
405                x1 = len(b) != 4
406            else:
407                x1 = True
408            assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
409            return len(x), len(b)
410
411        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
412        torch._C._jit_pass_constant_propagation(foo.graph)
413        # can infer len(b) == 4
414        test_const_tuple_output(foo.graph, [1])
415
416        # use __not__
417        @torch.jit.script
418        def foo(x: List[int], b: List[int]):
419            if len(x) != 5:
420                x1 = len(b) != 4
421            else:
422                x1 = True
423            assert not x1
424            return len(x), len(b)
425
426        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
427        torch._C._jit_pass_constant_propagation(foo.graph)
428        # can infer len(b) == 4
429        test_const_tuple_output(foo.graph, [1])
430
431        # Test unsuccessful optimizations
432
433        @torch.jit.script
434        def foo(x: List[int]):
435            assert len(x) == 4
436            x.append(3)
437            return len(x)
438
439        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
440        self.run_pass("constant_propagation", foo.graph)
441        FileCheck().check_count("aten::len", 2).run(foo.graph)
442
443        @torch.jit.script
444        def foo(x: List[int], y: List[int]):
445            assert len(x) == 4 or len(y) == 5
446            return len(x) + len(y)
447
448        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
449        self.run_pass("constant_propagation", foo.graph)
450        FileCheck().check_count("aten::len", 4).run(foo.graph)
451
452    def test_integer_refinement(self):
453        def run_peephole_and_check_const_value(graph, const_string):
454            self.run_pass("refine_integer_values", graph)
455            self.run_pass("constant_propagation", graph)
456            self.run_pass("dce", graph)
457            FileCheck().check(const_string).check_next("return").run(graph)
458
459        @torch.jit.script
460        def foo(x: int, y: int):
461            if x != 4 or y != 5:
462                raise Exception("")  # noqa: TRY002
463
464            return x + y
465
466        graph = foo.graph
467        self.run_pass("refine_integer_values", graph)
468        self.run_pass("constant_propagation", graph)
469        self.run_pass("dce", graph)
470
471        run_peephole_and_check_const_value(foo.graph, "value=9")
472        self.assertEqual(foo(4, 5), 9)
473        with self.assertRaises(Exception):
474            foo(2, 4)
475
476        @torch.jit.script
477        def foo(x: int, y: int):
478            if x == 4 and y == 5:
479                pass
480            else:
481                raise Exception("hi")  # noqa: TRY002
482
483            return x + y
484
485        run_peephole_and_check_const_value(foo.graph, "value=9")
486        self.assertEqual(foo(4, 5), 9)
487        with self.assertRaises(Exception):
488            foo(2, 4)
489
490        @torch.jit.script
491        def foo(x: int, y: int, z: int):
492            if x != 4:
493                raise Exception("..")  # noqa: TRY002
494            else:
495                if y != 8:
496                    raise Exception("...")  # noqa: TRY002
497                else:
498                    if z == 3:
499                        pass
500                    else:
501                        raise Exception("...")  # noqa: TRY002
502
503            return x + y * z
504
505        run_peephole_and_check_const_value(foo.graph, "value=28")
506        self.assertEqual(foo(4, 8, 3), 28)
507        with self.assertRaises(Exception):
508            foo(1, 2, 3)
509
510        # refinement should persist in second len(x) call
511
512        @torch.jit.script
513        def foo(x: int, cond: bool):
514            if x == 4:
515                if cond:
516                    return x
517                return 4
518
519            return 4
520
521        run_peephole_and_check_const_value(foo.graph, "value=4")
522
523        @torch.jit.script
524        def foo(x: int, y: int):
525            assert x == 4 or y == 5
526            return x + y
527
528        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
529        self.run_pass("constant_propagation", foo.graph)
530        FileCheck().check("aten::add").run(foo.graph)
531
532    def test_optimize_out_comparison_same_value(self):
533        def foo(x: int):
534            return x == x, x != x
535
536        def foo2(x: List[int]):
537            return x == x, x != x
538
539        for func, inp in zip([foo, foo2], [1, [2, 3]]):
540            func_s = torch.jit.script(func)
541            self.run_pass("peephole", func_s.graph)
542            FileCheck().check_not("aten::eq").check_not("aten::neq").run(func_s.graph)
543            self.assertEqual(func(inp), func_s(inp))
544
545    def test_peephole_add_zero(self):
546        @torch.jit.script
547        def foo(x: int):
548            return x + 0, 0 + x
549
550        self.run_pass("peephole", foo.graph)
551        FileCheck().check_not("aten::add")
552        self.assertEqual(foo(3), (3, 3))
553
554    def test_noop_peephole(self):
555        # test unsuccessful
556        def foo1(x):
557            return x + 0
558
559        def foo2():
560            x = torch.zeros([2, 2])
561            x.sub_(3)
562            return x + 0
563
564        def foo3():
565            x = torch.zeros([2, 2])
566            return x, x + 0
567
568        def foo4():
569            x = torch.zeros([2, 2])
570            return x + 0.0
571
572        funcs = foo1, foo2, foo3, foo4
573        inps = (torch.ones([2]),), (), (), ()
574        for func, inp in zip(funcs, inps):
575            foo_s = torch.jit.script(func)
576            self.run_pass("peephole", foo_s.graph)
577            FileCheck().check_count("aten::add", 1, exactly=True).run(foo_s.graph)
578            self.assertEqual(func(*inp), foo_s(*inp))
579
580        # successful
581        def func(x):
582            return (x + 0) * 1 - 5
583
584        func_s = torch.jit.script(func)
585        self.run_pass("peephole", func_s.graph)
586        # bail on modified value first
587        FileCheck().check_not("aten::add").check("aten::mul").run(func_s.graph)
588        # second run it should succeed
589        self.run_pass("peephole", func_s.graph)
590        FileCheck().check_not("aten::add").check_not("aten::mul").run(func_s.graph)
591        self.assertEqual(func(torch.ones([2, 2])), func_s(torch.ones([2, 2])))
592
593        def func(x):
594            return (x + 0.0) - 5
595
596        func_s = torch.jit.script(func)
597        inp = next(func_s.graph.inputs())
598        inp.setType(torch._C.TensorType.create_from_tensor(torch.rand([2, 2])))
599        torch._C._jit_pass_peephole(func_s.graph, disable_shape_peepholes=True)
600        FileCheck().check("aten::add").run(func_s.graph)
601        torch._C._jit_pass_peephole(func_s.graph, disable_shape_peepholes=False)
602        FileCheck().check_not("aten::add").run(func_s.graph)
603
604    def test_refine_integer_values(self):
605        @torch.jit.script
606        def foo(x: int):
607            y = 1
608            if x == 1:
609                return y
610            else:
611                return x
612
613        self.run_pass("refine_integer_values", foo.graph)
614        self.run_pass("constant_propagation", foo.graph)
615        self.run_pass("dce", foo.graph)
616        FileCheck().check("graph").check_next("return").run(foo.graph)
617        self.assertEqual(foo(2), 2)
618        self.assertEqual(foo(1), 1)
619
620    def test_peephole_len_list(self):
621        @torch.jit.script
622        def foo(x):
623            return len(x.size())
624
625        self.run_pass("peephole", foo.graph)
626        FileCheck().check("aten::len").run(foo.graph)
627        inputs = list(foo.graph.inputs())
628        inputs[0].setType(inputs[0].type().with_sizes([None, None]))
629        self.run_pass("peephole", foo.graph)
630        FileCheck().check_not("aten::len").run(foo.graph)
631        self.assertEqual(2, foo(torch.rand([3, 1])))
632
633        @torch.jit.script
634        def foo(x):
635            li = x.size()
636            li.append(4)
637            return len(li)
638
639        inputs = list(foo.graph.inputs())
640        inputs[0].setType(inputs[0].type().with_sizes([None, None]))
641        self.run_pass("peephole", foo.graph)
642        FileCheck().check("aten::len").run(foo.graph)
643        self.assertEqual(3, foo(torch.rand([3, 1])))
644
645    def test_peephole_optional_refine(self):
646        @torch.jit.script
647        def foo(z: int, z2: int, cond: bool):
648            if cond:
649                return z
650            else:
651                return z2
652
653        out = next(foo.graph.findNode("prim::If").outputs())
654        out.setType(torch._C.OptionalType(torch._C.IntType.get()))
655        self.run_pass("peephole", foo.graph)
656        FileCheck().check_not("int?").run(foo.graph)
657
658    def test_peephole_int(self):
659        @torch.jit.script
660        def foo(x):
661            # type: (number)
662            return int(x)
663
664        FileCheck().check("aten::Int").run(foo.graph)
665        next(foo.graph.inputs()).setType(torch._C.IntType.get())
666        self.run_pass("peephole", foo.graph)
667        FileCheck().check_not("aten::Int").run(foo.graph)
668
669    def test_peephole_arith(self):
670        @torch.jit.script
671        def foo(input0: int, input1: int, input2: int, input3: int):
672            _1 = torch.add(input1, 2)
673            _3 = torch.add(input3, 2)
674            _5 = torch.add(1, torch.sub(_1, 3) // 1)
675            _6 = torch.add(1 * torch.sub(_3, 3) // 1, 1) / 1
676            return [_5, int(_6)]
677
678        FileCheck().check("aten::add").check("aten::sub").check("aten::mul").check(
679            "aten::floordiv"
680        ).check("aten::div").run(foo.graph)
681        self.run_pass("peephole", foo.graph)
682        FileCheck().check("graph").check("):").check_next("ListConstruct").check_next(
683            "return"
684        ).run(foo.graph)
685        self.assertEqual(foo(0, 1, 2, 3), [1, 3])
686
687    def test_peephole_dict_getitem_simple(self):
688        @torch.jit.script
689        def foo(a: int, b: int):
690            d = {0: a, 1: b}
691            x = d[1]
692            y = d[0]
693            return x, y
694
695        self.run_pass("peephole", foo.graph)
696        FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph)
697        self.assertEqual(foo(0, 1), (1, 0))
698
699        @torch.jit.script
700        def foo(a: int, b: int):
701            d = {"0": a, "1": b}
702            x = d["1"]
703            y = d["0"]
704            return x, y
705
706        self.run_pass("peephole", foo.graph)
707        FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph)
708        self.assertEqual(foo(0, 1), (1, 0))
709
710        @torch.jit.script
711        def foo(a: int, b: int):
712            d = {0.0: a, 1.0: b}
713            x = d[1.0]
714            y = d[0.0]
715            return x, y
716
717        self.run_pass("peephole", foo.graph)
718        FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph)
719        self.assertEqual(foo(0, 1), (1, 0))
720
721    def test_peephole_dict_getitem_no_optimization_missing_key(self):
722        @torch.jit.script
723        def foo():
724            d = {0: 1}
725            return d[2]
726
727        self.run_pass("peephole", foo.graph)
728        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
729
730    def test_peephole_dict_getitem_no_optimization_get_input_arg(self):
731        # Here we don't know if the input arg is in the dict, so we can't
732        # make the optimization.
733        @torch.jit.script
734        def foo(a: int):
735            d = {0: 1}
736            return d[a]
737
738        self.run_pass("peephole", foo.graph)
739        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
740        self.assertEqual(foo(0), 1)
741
742    def test_peephole_dict_getitem_no_optimization_dict_modified(self):
743        @torch.jit.script
744        def foo():
745            d = {0: 1}
746            d[0] = 2
747            return d[0]
748
749        self.run_pass("peephole", foo.graph)
750        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
751        self.assertEqual(foo(), 2)
752
753    def test_peephole_dict_getitem_no_optimization_overlapping_keys(self):
754        @torch.jit.script
755        def foo():
756            d = {0: 1, 0: 2}  # noqa: F601
757            return d[0]
758
759        self.run_pass("peephole", foo.graph)
760        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
761
762    def test_peephole_dict_getitem_no_optimization_keys_might_overlap(self):
763        @torch.jit.script
764        def foo(x: int):
765            d = {0: 1, x: 2}
766            return d[x]
767
768        self.run_pass("peephole", foo.graph)
769        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
770
771    def test_peephole_dict_getitem_no_optimization_unsupported_type(self):
772        @torch.jit.script
773        def foo():
774            a = torch.rand((2, 2))
775            d = {a: 1}
776            return d[a]
777
778        self.run_pass("peephole", foo.graph)
779        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
780        self.assertEqual(foo(), 1)
781
782    def test_peephole_dict_len(self):
783        @torch.jit.script
784        def foo():
785            d = {0: 1, 1: 2}
786            return len(d)
787
788        self.run_pass("peephole", foo.graph)
789        FileCheck().check_not("DictConstruct").check_not("len").run(foo.graph)
790        self.assertEqual(foo(), 2)
791
792    def test_peephole_dict_len_no_optimization_overlapping_keys(self):
793        @torch.jit.script
794        def foo():
795            d = {0: 1, 0: 2}  # noqa: F601
796            return len(d)
797
798        self.run_pass("peephole", foo.graph)
799        FileCheck().check("DictConstruct").check("len").run(foo.graph)
800        self.assertEqual(foo(), 1)
801
802    def test_peephole_dict_len_no_optimization_keys_might_overlap(self):
803        @torch.jit.script
804        def foo(x: int):
805            d = {0: 1, x: 2}
806            return len(d)
807
808        self.run_pass("peephole", foo.graph)
809        FileCheck().check("DictConstruct").check("len").run(foo.graph)
810
811    def test_peephole_dict_len_no_optimization_unsupported_type(self):
812        @torch.jit.script
813        def foo():
814            a = torch.rand((2, 2))
815            d = {a: 1}
816            return len(d)
817
818        self.run_pass("peephole", foo.graph)
819        FileCheck().check("DictConstruct").check("len").run(foo.graph)
820        self.assertEqual(foo(), 1)
821
822    def test_peephole_slice_all_three_args(self):
823        def foo(x: int):
824            return [1, 2, x, 4, 5, 6, 7][-5:6:2]
825
826        graph = torch.jit.script(foo).graph
827        self.run_pass("peephole", graph)
828        FileCheck().check_not("aten::slice").run(graph)
829        self.checkScript(foo, (3,))
830
831    def test_peephole_slice_one_empty_arg(self):
832        def check_helper(fn: Callable[[int], None]) -> None:
833            graph = torch.jit.script(fn).graph
834            self.run_pass("peephole", graph)
835            FileCheck().check_not("aten::slice").run(graph)
836            self.checkScript(fn, (3,))
837
838        def foo(x: int):
839            return [1, 2, x, 4, 5, 6, 7][1::2]
840
841        check_helper(foo)
842
843        def foo(x: int):
844            return [1, 2, x, 4, 5, 6, 7][:5:3]
845
846        check_helper(foo)
847
848        def foo(x: int):
849            return [1, 2, x, 4, 5, 6, 7][0:4]
850
851        check_helper(foo)
852
853    def test_peephole_slice_two_empty_args(self):
854        def check_helper(fn: Callable[[int], None]) -> None:
855            graph = torch.jit.script(fn).graph
856            self.run_pass("peephole", graph)
857            FileCheck().check_not("aten::slice").run(graph)
858            self.checkScript(fn, (3,))
859
860        def foo(x: int):
861            return [1, 2, x, 4, 5, 6, 7][::2]
862
863        check_helper(foo)
864
865        def foo(x: int):
866            return [1, 2, x, 4, 5, 6, 7][:5]
867
868        check_helper(foo)
869
870        def foo(x: int):
871            return [1, 2, x, 4, 5, 6, 7][1:]
872
873        check_helper(foo)
874
875    def test_peephole_slice_optimization_not_applied_list_modified(self):
876        @torch.jit.script
877        def foo():
878            li = [1, 2, 3, 4, 5, 6, 7]
879            li[0] = 0
880            return li[2:5]
881
882        self.run_pass("peephole", foo.graph)
883        FileCheck().check("aten::slice").run(foo.graph)
884
885    def test_peephole_slice_optimization_not_applied_non_const_args(self):
886        @torch.jit.script
887        def foo(x: int, y: int):
888            li = [1, 2, 3, 4, 5, 6, 7]
889            return li[x:y]
890
891        self.run_pass("peephole", foo.graph)
892        FileCheck().check("aten::slice").run(foo.graph)
893