xref: /aosp_15_r20/external/pytorch/test/jit/test_peephole.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport unittest
4*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, List
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom torch import nn
8*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # These tests are collected and driven by test/test_jit.py; refuse to
    # run this module on its own and point the user at the right entry.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Workerclass TestPeephole(JitTestCase):
21*da0073e9SAndroid Build Coastguard Worker    def test_peephole_with_writes(self):
22*da0073e9SAndroid Build Coastguard Worker        def test_write(x):
23*da0073e9SAndroid Build Coastguard Worker            s = 0
24*da0073e9SAndroid Build Coastguard Worker            s += x
25*da0073e9SAndroid Build Coastguard Worker            s += x
26*da0073e9SAndroid Build Coastguard Worker            return s
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker        self.checkScript(test_write, (torch.ones(4, 4),))
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker    def test_peephole_with_non_output_writes(self):
31*da0073e9SAndroid Build Coastguard Worker        @torch.jit.ignore
32*da0073e9SAndroid Build Coastguard Worker        def nomnom(x):
33*da0073e9SAndroid Build Coastguard Worker            pass
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker        def test_write(x):
36*da0073e9SAndroid Build Coastguard Worker            t = torch.ones_like(x)
37*da0073e9SAndroid Build Coastguard Worker            z = x.clone()
38*da0073e9SAndroid Build Coastguard Worker            y = z + 0
39*da0073e9SAndroid Build Coastguard Worker            z.add_(t)
40*da0073e9SAndroid Build Coastguard Worker            # this makes sure z isn't blasted out of existence
41*da0073e9SAndroid Build Coastguard Worker            # because it isn't returned or used in a side-effectful
42*da0073e9SAndroid Build Coastguard Worker            # way
43*da0073e9SAndroid Build Coastguard Worker            nomnom(z)
44*da0073e9SAndroid Build Coastguard Worker            return y + y
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(4, 4)
47*da0073e9SAndroid Build Coastguard Worker        j = self.checkScript(test_write, (a,))
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker    def test_peephole_no_output_aliasing(self):
50*da0073e9SAndroid Build Coastguard Worker        def test_peephole(x):
51*da0073e9SAndroid Build Coastguard Worker            y = x + 0
52*da0073e9SAndroid Build Coastguard Worker            return x, y
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(4, 4)
55*da0073e9SAndroid Build Coastguard Worker        j = self.checkScript(test_peephole, (a,))
56*da0073e9SAndroid Build Coastguard Worker        r1, r2 = j(a)
57*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(r1.data_ptr(), r2.data_ptr())
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker    def test_peephole(self):
60*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([0.4])
61*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([0.7])
62*da0073e9SAndroid Build Coastguard Worker        c = torch.tensor([0], dtype=torch.int32)
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
65*da0073e9SAndroid Build Coastguard Worker            return x.type_as(y)
66*da0073e9SAndroid Build Coastguard Worker
67*da0073e9SAndroid Build Coastguard Worker        tf = torch.jit.trace(f, (a, b))
68*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("type_as").run(str(tf.graph))
69*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", tf.graph)
70*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("type_as").run(str(tf.graph))
71*da0073e9SAndroid Build Coastguard Worker        tf2 = torch.jit.trace(f, (a, c))
72*da0073e9SAndroid Build Coastguard Worker        s = str(tf2.graph)
73*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", tf2.graph)
74*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s, str(s))
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dynamic(self):
77*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
78*da0073e9SAndroid Build Coastguard Worker            return x.type_as(y)
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker        fn = torch.jit.script(f)
81*da0073e9SAndroid Build Coastguard Worker        s = str(fn.graph)
82*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole(fn.graph)
83*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s, str(fn.graph))
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker    def test_peephole_list_ops(self):
86*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
87*da0073e9SAndroid Build Coastguard Worker        def foo(x, y, z):
88*da0073e9SAndroid Build Coastguard Worker            return len([x, y, z])
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
91*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("value=3").check_next("return").run(foo.graph)
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
94*da0073e9SAndroid Build Coastguard Worker        def foo(x, y, z):
95*da0073e9SAndroid Build Coastguard Worker            li = [x, y, z]
96*da0073e9SAndroid Build Coastguard Worker            for i in range(len(x)):
97*da0073e9SAndroid Build Coastguard Worker                li.append(x)
98*da0073e9SAndroid Build Coastguard Worker            return len([x, y, z])
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
101*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::len").run(foo.graph)
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
104*da0073e9SAndroid Build Coastguard Worker        def foo(x, y, z):
105*da0073e9SAndroid Build Coastguard Worker            li = [x, y, z]
106*da0073e9SAndroid Build Coastguard Worker            return li[1], li[-2]
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::__getitem__").run(foo.graph)
109*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
110*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::__getitem__").run(foo.graph)
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
113*da0073e9SAndroid Build Coastguard Worker        def foo(x, y, z):
114*da0073e9SAndroid Build Coastguard Worker            li = [x, y, z]
115*da0073e9SAndroid Build Coastguard Worker            return li[-7]
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
118*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::__getitem__").run(foo.graph)
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
121*da0073e9SAndroid Build Coastguard Worker        def foo(x, y, z):
122*da0073e9SAndroid Build Coastguard Worker            li = [x, y, z]
123*da0073e9SAndroid Build Coastguard Worker            for i in range(len(x)):
124*da0073e9SAndroid Build Coastguard Worker                li.append(x)
125*da0073e9SAndroid Build Coastguard Worker            return li[-2]
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
128*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::__getitem__").run(foo.graph)
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
131*da0073e9SAndroid Build Coastguard Worker    def test_peephole_cuda(self):
132*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([0.4], device="cpu")
133*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([0.7], device="cuda")
134*da0073e9SAndroid Build Coastguard Worker        c = torch.tensor([0.7], device="cuda")
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
137*da0073e9SAndroid Build Coastguard Worker            return x.type_as(y)
138*da0073e9SAndroid Build Coastguard Worker
139*da0073e9SAndroid Build Coastguard Worker        trace = torch.jit.trace(f, (a, c))
140*da0073e9SAndroid Build Coastguard Worker        s = str(trace.graph)
141*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", trace.graph)
142*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s, str(trace.graph))
143*da0073e9SAndroid Build Coastguard Worker        trace = torch.jit.trace(f, (b, c))
144*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", trace.graph)
145*da0073e9SAndroid Build Coastguard Worker        self.run_pass("dce", trace.graph)
146*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("type_as").run(str(trace.graph))
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker    @_inline_everything
149*da0073e9SAndroid Build Coastguard Worker    def test_peephole_type_refinements(self):
150*da0073e9SAndroid Build Coastguard Worker        def refine(x):
151*da0073e9SAndroid Build Coastguard Worker            # type: (Optional[Tensor]) -> Tensor
152*da0073e9SAndroid Build Coastguard Worker            return x if x is not None else torch.tensor(3)
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
155*da0073e9SAndroid Build Coastguard Worker        def test():
156*da0073e9SAndroid Build Coastguard Worker            return refine(torch.tensor(4))
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::unchecked_cast").run(test.graph)
159*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", test.graph)
160*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("prim::unchecked_cast").run(test.graph)
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Worker        # refinement not optimzied out
163*da0073e9SAndroid Build Coastguard Worker        def is_int_tensor(x):
164*da0073e9SAndroid Build Coastguard Worker            scalar = x.item()
165*da0073e9SAndroid Build Coastguard Worker            if isinstance(scalar, int):
166*da0073e9SAndroid Build Coastguard Worker                return scalar + 3
167*da0073e9SAndroid Build Coastguard Worker            else:
168*da0073e9SAndroid Build Coastguard Worker                return 8
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker        self.checkScript(is_int_tensor, (torch.tensor(2),))
171*da0073e9SAndroid Build Coastguard Worker        self.checkScript(is_int_tensor, (torch.tensor(2.5),))
172*da0073e9SAndroid Build Coastguard Worker        graph = torch.jit.script(is_int_tensor).graph
173*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", graph)
174*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("prim::unchecked_cast").run(graph)
175*da0073e9SAndroid Build Coastguard Worker
176*da0073e9SAndroid Build Coastguard Worker    def test_short_circuit_optimization(self):
177*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
178*da0073e9SAndroid Build Coastguard Worker        def const_expressions(x):
179*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> Tuple[bool, bool]
180*da0073e9SAndroid Build Coastguard Worker            return x == 1 and False, x == 1 or True
181*da0073e9SAndroid Build Coastguard Worker
182*da0073e9SAndroid Build Coastguard Worker        self.run_pass("constant_propagation", const_expressions.graph)
183*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("prim::If").check_not("aten::eq").run(
184*da0073e9SAndroid Build Coastguard Worker            const_expressions.graph
185*da0073e9SAndroid Build Coastguard Worker        )
186*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(const_expressions(1), (False, True))
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
189*da0073e9SAndroid Build Coastguard Worker        def redundant_expressions(x):
190*da0073e9SAndroid Build Coastguard Worker            # type: (int) -> Tuple[bool, bool]
191*da0073e9SAndroid Build Coastguard Worker            return x == 1 and True, x == 1 or False
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", redundant_expressions.graph)
194*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(redundant_expressions(1), (True, True))
195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(redundant_expressions(0), (False, False))
196*da0073e9SAndroid Build Coastguard Worker        # and True / or False are removed from graph
197*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::eq").check_not("prim::If").run(
198*da0073e9SAndroid Build Coastguard Worker            redundant_expressions.graph
199*da0073e9SAndroid Build Coastguard Worker        )
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker    def test_conv_dim_folding(self):
202*da0073e9SAndroid Build Coastguard Worker        modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
203*da0073e9SAndroid Build Coastguard Worker        for mod in modules:
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker            class ConvDim(torch.nn.Module):
206*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
207*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
208*da0073e9SAndroid Build Coastguard Worker                    self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False)
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
211*da0073e9SAndroid Build Coastguard Worker                    x = self.conv(x)
212*da0073e9SAndroid Build Coastguard Worker                    return x.dim()
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker            conv_dim = torch.jit.script(ConvDim())
215*da0073e9SAndroid Build Coastguard Worker            self.run_pass("inline", conv_dim.graph)
216*da0073e9SAndroid Build Coastguard Worker            self.run_pass("peephole", conv_dim.graph)
217*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("conv").check_not("dim").run(conv_dim.graph)
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker            class ConvDimMutate(torch.nn.Module):
220*da0073e9SAndroid Build Coastguard Worker                def __init__(self) -> None:
221*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
222*da0073e9SAndroid Build Coastguard Worker                    self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False)
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
225*da0073e9SAndroid Build Coastguard Worker                    x = self.conv(x)
226*da0073e9SAndroid Build Coastguard Worker                    x.resize_([4, 4])
227*da0073e9SAndroid Build Coastguard Worker                    return x.dim()
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker            conv_dim = torch.jit.script(ConvDimMutate())
230*da0073e9SAndroid Build Coastguard Worker            self.run_pass("inline", conv_dim.graph)
231*da0073e9SAndroid Build Coastguard Worker            self.run_pass("peephole", conv_dim.graph)
232*da0073e9SAndroid Build Coastguard Worker            FileCheck().check("conv").check("dim").run(conv_dim.graph)
233*da0073e9SAndroid Build Coastguard Worker
234*da0073e9SAndroid Build Coastguard Worker    def test_normalized_rsub(self):
235*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1, 2, 3])
236*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([4, 5, 6])
237*da0073e9SAndroid Build Coastguard Worker
238*da0073e9SAndroid Build Coastguard Worker        def convertible_rsub(x, y):
239*da0073e9SAndroid Build Coastguard Worker            return (x - y), torch.rsub(y, x)
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker        self.checkScript(convertible_rsub, (a, b))
242*da0073e9SAndroid Build Coastguard Worker        op_graph = torch.jit.script(convertible_rsub).graph
243*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::sub", 2, exactly=True).run(op_graph)
244*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::rsub", 0, exactly=True).run(op_graph)
245*da0073e9SAndroid Build Coastguard Worker
246*da0073e9SAndroid Build Coastguard Worker    def test_normalized_is_op(self):
247*da0073e9SAndroid Build Coastguard Worker        def convertible_is_op(x: bool, y: bool):
248*da0073e9SAndroid Build Coastguard Worker            return x is True, False is x, x is y
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker        self.checkScript(convertible_is_op, (True, False))
251*da0073e9SAndroid Build Coastguard Worker
252*da0073e9SAndroid Build Coastguard Worker        op_graph = torch.jit.script(convertible_is_op).graph
253*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::eq", 3, exactly=True).run(op_graph)
254*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::__is__", 0, exactly=True).run(op_graph)
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker    def test_normalized_isnot_op(self):
257*da0073e9SAndroid Build Coastguard Worker        def convertible_isnot_op(x: bool, y: bool):
258*da0073e9SAndroid Build Coastguard Worker            return x is not True, False is not x, x is not y
259*da0073e9SAndroid Build Coastguard Worker
260*da0073e9SAndroid Build Coastguard Worker        self.checkScript(convertible_isnot_op, (True, False))
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker        op_graph = torch.jit.script(convertible_isnot_op).graph
263*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::ne", 3, exactly=True).run(op_graph)
264*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::__isnot__", 0, exactly=True).run(op_graph)
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker    def test_peephole_list_len(self):
267*da0073e9SAndroid Build Coastguard Worker        def run_peephole_and_check_const_value(graph, const_string):
268*da0073e9SAndroid Build Coastguard Worker            torch._C._jit_pass_peephole_list_idioms(graph, refine_list_len=True)
269*da0073e9SAndroid Build Coastguard Worker            self.run_pass("constant_propagation", graph)
270*da0073e9SAndroid Build Coastguard Worker            FileCheck().check(const_string).check_next("return").run(graph)
271*da0073e9SAndroid Build Coastguard Worker
272*da0073e9SAndroid Build Coastguard Worker        def gen_li(inp_len: int):
273*da0073e9SAndroid Build Coastguard Worker            return [0 for i in range(inp_len)]
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
276*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], y: List[int]):
277*da0073e9SAndroid Build Coastguard Worker            if len(x) != 4 or len(y) != 5:
278*da0073e9SAndroid Build Coastguard Worker                raise Exception("")  # noqa: TRY002
279*da0073e9SAndroid Build Coastguard Worker
280*da0073e9SAndroid Build Coastguard Worker            return len(x) + len(y)
281*da0073e9SAndroid Build Coastguard Worker
282*da0073e9SAndroid Build Coastguard Worker        run_peephole_and_check_const_value(foo.graph, "value=9")
283*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(gen_li(4), gen_li(5)), 9)
284*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
285*da0073e9SAndroid Build Coastguard Worker            foo(2, 4)
286*da0073e9SAndroid Build Coastguard Worker
287*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
288*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], y: List[int]):
289*da0073e9SAndroid Build Coastguard Worker            if len(x) == 4 and len(y) == 5:
290*da0073e9SAndroid Build Coastguard Worker                pass
291*da0073e9SAndroid Build Coastguard Worker            else:
292*da0073e9SAndroid Build Coastguard Worker                raise Exception("hi")  # noqa: TRY002
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker            return len(x) + len(y)
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker        run_peephole_and_check_const_value(foo.graph, "value=9")
297*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(gen_li(4), gen_li(5)), 9)
298*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
299*da0073e9SAndroid Build Coastguard Worker            foo(2, 4)
300*da0073e9SAndroid Build Coastguard Worker
301*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
302*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], y: List[int], z: List[int]):
303*da0073e9SAndroid Build Coastguard Worker            if len(x) != 4:
304*da0073e9SAndroid Build Coastguard Worker                raise Exception("..")  # noqa: TRY002
305*da0073e9SAndroid Build Coastguard Worker            else:
306*da0073e9SAndroid Build Coastguard Worker                if len(y) != 8:
307*da0073e9SAndroid Build Coastguard Worker                    raise Exception("...")  # noqa: TRY002
308*da0073e9SAndroid Build Coastguard Worker                else:
309*da0073e9SAndroid Build Coastguard Worker                    if len(z) == 3:
310*da0073e9SAndroid Build Coastguard Worker                        pass
311*da0073e9SAndroid Build Coastguard Worker                    else:
312*da0073e9SAndroid Build Coastguard Worker                        raise Exception("...")  # noqa: TRY002
313*da0073e9SAndroid Build Coastguard Worker
314*da0073e9SAndroid Build Coastguard Worker            return len(x) + len(y) * len(z)
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker        run_peephole_and_check_const_value(foo.graph, "value=28")
317*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(gen_li(4), gen_li(8), gen_li(3)), 28)
318*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
319*da0073e9SAndroid Build Coastguard Worker            foo(1, 2, 3)
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker        # refinement should persist in second len(x) call
322*da0073e9SAndroid Build Coastguard Worker
323*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
324*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], cond: bool):
325*da0073e9SAndroid Build Coastguard Worker            if len(x) == 4:
326*da0073e9SAndroid Build Coastguard Worker                if cond:
327*da0073e9SAndroid Build Coastguard Worker                    return len(x)
328*da0073e9SAndroid Build Coastguard Worker                return 4
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker            return 4
331*da0073e9SAndroid Build Coastguard Worker
332*da0073e9SAndroid Build Coastguard Worker        run_peephole_and_check_const_value(foo.graph, "value=4")
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Worker        def test_const_tuple_output(graph, const_inputs):
335*da0073e9SAndroid Build Coastguard Worker            tup = graph.findNode("prim::TupleConstruct")
336*da0073e9SAndroid Build Coastguard Worker            for i, elem in enumerate(tup.inputs()):
337*da0073e9SAndroid Build Coastguard Worker                if i in const_inputs:
338*da0073e9SAndroid Build Coastguard Worker                    self.assertIsNotNone(elem.toIValue())
339*da0073e9SAndroid Build Coastguard Worker                else:
340*da0073e9SAndroid Build Coastguard Worker                    self.assertIsNone(elem.toIValue())
341*da0073e9SAndroid Build Coastguard Worker
342*da0073e9SAndroid Build Coastguard Worker        # testing combinations of x1 : {True, False} x
343*da0073e9SAndroid Build Coastguard Worker        # {then/else branch} x assert {True/False}
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
346*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], b: List[int]):
347*da0073e9SAndroid Build Coastguard Worker            if len(x) == 5:
348*da0073e9SAndroid Build Coastguard Worker                x1 = True
349*da0073e9SAndroid Build Coastguard Worker            else:
350*da0073e9SAndroid Build Coastguard Worker                x1 = len(b) != 4
351*da0073e9SAndroid Build Coastguard Worker            assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
352*da0073e9SAndroid Build Coastguard Worker            return len(x), len(b)
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
355*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_constant_propagation(foo.graph)
356*da0073e9SAndroid Build Coastguard Worker        # we can only infer len(b) == 4 here
357*da0073e9SAndroid Build Coastguard Worker        test_const_tuple_output(foo.graph, [1])
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
360*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], b: List[int]):
361*da0073e9SAndroid Build Coastguard Worker            if len(x) == 5:
362*da0073e9SAndroid Build Coastguard Worker                x1 = False
363*da0073e9SAndroid Build Coastguard Worker            else:
364*da0073e9SAndroid Build Coastguard Worker                x1 = len(b) != 4
365*da0073e9SAndroid Build Coastguard Worker            assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
366*da0073e9SAndroid Build Coastguard Worker            return len(x), len(b)
367*da0073e9SAndroid Build Coastguard Worker
368*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
369*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_constant_propagation(foo.graph)
370*da0073e9SAndroid Build Coastguard Worker        # cant infer anything
371*da0073e9SAndroid Build Coastguard Worker        test_const_tuple_output(foo.graph, [])
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
374*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], b: List[int]):
375*da0073e9SAndroid Build Coastguard Worker            if len(x) == 5:
376*da0073e9SAndroid Build Coastguard Worker                x1 = True
377*da0073e9SAndroid Build Coastguard Worker            else:
378*da0073e9SAndroid Build Coastguard Worker                x1 = len(b) == 4
379*da0073e9SAndroid Build Coastguard Worker            assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
380*da0073e9SAndroid Build Coastguard Worker            return len(x), len(b)
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
383*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_constant_propagation(foo.graph)
384*da0073e9SAndroid Build Coastguard Worker        # we cant infer anything, only len(b) != 4
385*da0073e9SAndroid Build Coastguard Worker        test_const_tuple_output(foo.graph, [])
386*da0073e9SAndroid Build Coastguard Worker
387*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
388*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], b: List[int]):
389*da0073e9SAndroid Build Coastguard Worker            if len(x) == 5:
390*da0073e9SAndroid Build Coastguard Worker                x1 = True
391*da0073e9SAndroid Build Coastguard Worker            else:
392*da0073e9SAndroid Build Coastguard Worker                x1 = len(b) != 4
393*da0073e9SAndroid Build Coastguard Worker            assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
394*da0073e9SAndroid Build Coastguard Worker            return len(x), len(b)
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
397*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_constant_propagation(foo.graph)
398*da0073e9SAndroid Build Coastguard Worker        # can infer len(b) == 4
399*da0073e9SAndroid Build Coastguard Worker        test_const_tuple_output(foo.graph, [1])
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker        # swap branches
402*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
403*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], b: List[int]):
404*da0073e9SAndroid Build Coastguard Worker            if len(x) != 5:
405*da0073e9SAndroid Build Coastguard Worker                x1 = len(b) != 4
406*da0073e9SAndroid Build Coastguard Worker            else:
407*da0073e9SAndroid Build Coastguard Worker                x1 = True
408*da0073e9SAndroid Build Coastguard Worker            assert x1 == False  # noqa: E712 TODO: canonicalize x is False to aten::eq
409*da0073e9SAndroid Build Coastguard Worker            return len(x), len(b)
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
412*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_constant_propagation(foo.graph)
413*da0073e9SAndroid Build Coastguard Worker        # can infer len(b) == 4
414*da0073e9SAndroid Build Coastguard Worker        test_const_tuple_output(foo.graph, [1])
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker        # use __not__
417*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
418*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], b: List[int]):
419*da0073e9SAndroid Build Coastguard Worker            if len(x) != 5:
420*da0073e9SAndroid Build Coastguard Worker                x1 = len(b) != 4
421*da0073e9SAndroid Build Coastguard Worker            else:
422*da0073e9SAndroid Build Coastguard Worker                x1 = True
423*da0073e9SAndroid Build Coastguard Worker            assert not x1
424*da0073e9SAndroid Build Coastguard Worker            return len(x), len(b)
425*da0073e9SAndroid Build Coastguard Worker
426*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
427*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_constant_propagation(foo.graph)
428*da0073e9SAndroid Build Coastguard Worker        # can infer len(b) == 4
429*da0073e9SAndroid Build Coastguard Worker        test_const_tuple_output(foo.graph, [1])
430*da0073e9SAndroid Build Coastguard Worker
431*da0073e9SAndroid Build Coastguard Worker        # Test unsuccessful optimizations
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
434*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int]):
435*da0073e9SAndroid Build Coastguard Worker            assert len(x) == 4
436*da0073e9SAndroid Build Coastguard Worker            x.append(3)
437*da0073e9SAndroid Build Coastguard Worker            return len(x)
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
440*da0073e9SAndroid Build Coastguard Worker        self.run_pass("constant_propagation", foo.graph)
441*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::len", 2).run(foo.graph)
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
444*da0073e9SAndroid Build Coastguard Worker        def foo(x: List[int], y: List[int]):
445*da0073e9SAndroid Build Coastguard Worker            assert len(x) == 4 or len(y) == 5
446*da0073e9SAndroid Build Coastguard Worker            return len(x) + len(y)
447*da0073e9SAndroid Build Coastguard Worker
448*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
449*da0073e9SAndroid Build Coastguard Worker        self.run_pass("constant_propagation", foo.graph)
450*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_count("aten::len", 4).run(foo.graph)
451*da0073e9SAndroid Build Coastguard Worker
452*da0073e9SAndroid Build Coastguard Worker    def test_integer_refinement(self):
453*da0073e9SAndroid Build Coastguard Worker        def run_peephole_and_check_const_value(graph, const_string):
454*da0073e9SAndroid Build Coastguard Worker            self.run_pass("refine_integer_values", graph)
455*da0073e9SAndroid Build Coastguard Worker            self.run_pass("constant_propagation", graph)
456*da0073e9SAndroid Build Coastguard Worker            self.run_pass("dce", graph)
457*da0073e9SAndroid Build Coastguard Worker            FileCheck().check(const_string).check_next("return").run(graph)
458*da0073e9SAndroid Build Coastguard Worker
459*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
460*da0073e9SAndroid Build Coastguard Worker        def foo(x: int, y: int):
461*da0073e9SAndroid Build Coastguard Worker            if x != 4 or y != 5:
462*da0073e9SAndroid Build Coastguard Worker                raise Exception("")  # noqa: TRY002
463*da0073e9SAndroid Build Coastguard Worker
464*da0073e9SAndroid Build Coastguard Worker            return x + y
465*da0073e9SAndroid Build Coastguard Worker
466*da0073e9SAndroid Build Coastguard Worker        graph = foo.graph
467*da0073e9SAndroid Build Coastguard Worker        self.run_pass("refine_integer_values", graph)
468*da0073e9SAndroid Build Coastguard Worker        self.run_pass("constant_propagation", graph)
469*da0073e9SAndroid Build Coastguard Worker        self.run_pass("dce", graph)
470*da0073e9SAndroid Build Coastguard Worker
471*da0073e9SAndroid Build Coastguard Worker        run_peephole_and_check_const_value(foo.graph, "value=9")
472*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(4, 5), 9)
473*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
474*da0073e9SAndroid Build Coastguard Worker            foo(2, 4)
475*da0073e9SAndroid Build Coastguard Worker
476*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
477*da0073e9SAndroid Build Coastguard Worker        def foo(x: int, y: int):
478*da0073e9SAndroid Build Coastguard Worker            if x == 4 and y == 5:
479*da0073e9SAndroid Build Coastguard Worker                pass
480*da0073e9SAndroid Build Coastguard Worker            else:
481*da0073e9SAndroid Build Coastguard Worker                raise Exception("hi")  # noqa: TRY002
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker            return x + y
484*da0073e9SAndroid Build Coastguard Worker
485*da0073e9SAndroid Build Coastguard Worker        run_peephole_and_check_const_value(foo.graph, "value=9")
486*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(4, 5), 9)
487*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
488*da0073e9SAndroid Build Coastguard Worker            foo(2, 4)
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
491*da0073e9SAndroid Build Coastguard Worker        def foo(x: int, y: int, z: int):
492*da0073e9SAndroid Build Coastguard Worker            if x != 4:
493*da0073e9SAndroid Build Coastguard Worker                raise Exception("..")  # noqa: TRY002
494*da0073e9SAndroid Build Coastguard Worker            else:
495*da0073e9SAndroid Build Coastguard Worker                if y != 8:
496*da0073e9SAndroid Build Coastguard Worker                    raise Exception("...")  # noqa: TRY002
497*da0073e9SAndroid Build Coastguard Worker                else:
498*da0073e9SAndroid Build Coastguard Worker                    if z == 3:
499*da0073e9SAndroid Build Coastguard Worker                        pass
500*da0073e9SAndroid Build Coastguard Worker                    else:
501*da0073e9SAndroid Build Coastguard Worker                        raise Exception("...")  # noqa: TRY002
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker            return x + y * z
504*da0073e9SAndroid Build Coastguard Worker
505*da0073e9SAndroid Build Coastguard Worker        run_peephole_and_check_const_value(foo.graph, "value=28")
506*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(4, 8, 3), 28)
507*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(Exception):
508*da0073e9SAndroid Build Coastguard Worker            foo(1, 2, 3)
509*da0073e9SAndroid Build Coastguard Worker
510*da0073e9SAndroid Build Coastguard Worker        # refinement should persist in second len(x) call
511*da0073e9SAndroid Build Coastguard Worker
512*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
513*da0073e9SAndroid Build Coastguard Worker        def foo(x: int, cond: bool):
514*da0073e9SAndroid Build Coastguard Worker            if x == 4:
515*da0073e9SAndroid Build Coastguard Worker                if cond:
516*da0073e9SAndroid Build Coastguard Worker                    return x
517*da0073e9SAndroid Build Coastguard Worker                return 4
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Worker            return 4
520*da0073e9SAndroid Build Coastguard Worker
521*da0073e9SAndroid Build Coastguard Worker        run_peephole_and_check_const_value(foo.graph, "value=4")
522*da0073e9SAndroid Build Coastguard Worker
523*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
524*da0073e9SAndroid Build Coastguard Worker        def foo(x: int, y: int):
525*da0073e9SAndroid Build Coastguard Worker            assert x == 4 or y == 5
526*da0073e9SAndroid Build Coastguard Worker            return x + y
527*da0073e9SAndroid Build Coastguard Worker
528*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True)
529*da0073e9SAndroid Build Coastguard Worker        self.run_pass("constant_propagation", foo.graph)
530*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::add").run(foo.graph)
531*da0073e9SAndroid Build Coastguard Worker
532*da0073e9SAndroid Build Coastguard Worker    def test_optimize_out_comparison_same_value(self):
533*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
534*da0073e9SAndroid Build Coastguard Worker            return x == x, x != x
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker        def foo2(x: List[int]):
537*da0073e9SAndroid Build Coastguard Worker            return x == x, x != x
538*da0073e9SAndroid Build Coastguard Worker
539*da0073e9SAndroid Build Coastguard Worker        for func, inp in zip([foo, foo2], [1, [2, 3]]):
540*da0073e9SAndroid Build Coastguard Worker            func_s = torch.jit.script(func)
541*da0073e9SAndroid Build Coastguard Worker            self.run_pass("peephole", func_s.graph)
542*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("aten::eq").check_not("aten::neq").run(func_s.graph)
543*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(func(inp), func_s(inp))
544*da0073e9SAndroid Build Coastguard Worker
545*da0073e9SAndroid Build Coastguard Worker    def test_peephole_add_zero(self):
546*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
547*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
548*da0073e9SAndroid Build Coastguard Worker            return x + 0, 0 + x
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
551*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::add")
552*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(3), (3, 3))
553*da0073e9SAndroid Build Coastguard Worker
554*da0073e9SAndroid Build Coastguard Worker    def test_noop_peephole(self):
555*da0073e9SAndroid Build Coastguard Worker        # test unsuccessful
556*da0073e9SAndroid Build Coastguard Worker        def foo1(x):
557*da0073e9SAndroid Build Coastguard Worker            return x + 0
558*da0073e9SAndroid Build Coastguard Worker
559*da0073e9SAndroid Build Coastguard Worker        def foo2():
560*da0073e9SAndroid Build Coastguard Worker            x = torch.zeros([2, 2])
561*da0073e9SAndroid Build Coastguard Worker            x.sub_(3)
562*da0073e9SAndroid Build Coastguard Worker            return x + 0
563*da0073e9SAndroid Build Coastguard Worker
564*da0073e9SAndroid Build Coastguard Worker        def foo3():
565*da0073e9SAndroid Build Coastguard Worker            x = torch.zeros([2, 2])
566*da0073e9SAndroid Build Coastguard Worker            return x, x + 0
567*da0073e9SAndroid Build Coastguard Worker
568*da0073e9SAndroid Build Coastguard Worker        def foo4():
569*da0073e9SAndroid Build Coastguard Worker            x = torch.zeros([2, 2])
570*da0073e9SAndroid Build Coastguard Worker            return x + 0.0
571*da0073e9SAndroid Build Coastguard Worker
572*da0073e9SAndroid Build Coastguard Worker        funcs = foo1, foo2, foo3, foo4
573*da0073e9SAndroid Build Coastguard Worker        inps = (torch.ones([2]),), (), (), ()
574*da0073e9SAndroid Build Coastguard Worker        for func, inp in zip(funcs, inps):
575*da0073e9SAndroid Build Coastguard Worker            foo_s = torch.jit.script(func)
576*da0073e9SAndroid Build Coastguard Worker            self.run_pass("peephole", foo_s.graph)
577*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_count("aten::add", 1, exactly=True).run(foo_s.graph)
578*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(func(*inp), foo_s(*inp))
579*da0073e9SAndroid Build Coastguard Worker
580*da0073e9SAndroid Build Coastguard Worker        # successful
581*da0073e9SAndroid Build Coastguard Worker        def func(x):
582*da0073e9SAndroid Build Coastguard Worker            return (x + 0) * 1 - 5
583*da0073e9SAndroid Build Coastguard Worker
584*da0073e9SAndroid Build Coastguard Worker        func_s = torch.jit.script(func)
585*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", func_s.graph)
586*da0073e9SAndroid Build Coastguard Worker        # bail on modified value first
587*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::add").check("aten::mul").run(func_s.graph)
588*da0073e9SAndroid Build Coastguard Worker        # second run it should succeed
589*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", func_s.graph)
590*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::add").check_not("aten::mul").run(func_s.graph)
591*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(func(torch.ones([2, 2])), func_s(torch.ones([2, 2])))
592*da0073e9SAndroid Build Coastguard Worker
593*da0073e9SAndroid Build Coastguard Worker        def func(x):
594*da0073e9SAndroid Build Coastguard Worker            return (x + 0.0) - 5
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker        func_s = torch.jit.script(func)
597*da0073e9SAndroid Build Coastguard Worker        inp = next(func_s.graph.inputs())
598*da0073e9SAndroid Build Coastguard Worker        inp.setType(torch._C.TensorType.create_from_tensor(torch.rand([2, 2])))
599*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole(func_s.graph, disable_shape_peepholes=True)
600*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::add").run(func_s.graph)
601*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_peephole(func_s.graph, disable_shape_peepholes=False)
602*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::add").run(func_s.graph)
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker    def test_refine_integer_values(self):
605*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
606*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
607*da0073e9SAndroid Build Coastguard Worker            y = 1
608*da0073e9SAndroid Build Coastguard Worker            if x == 1:
609*da0073e9SAndroid Build Coastguard Worker                return y
610*da0073e9SAndroid Build Coastguard Worker            else:
611*da0073e9SAndroid Build Coastguard Worker                return x
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker        self.run_pass("refine_integer_values", foo.graph)
614*da0073e9SAndroid Build Coastguard Worker        self.run_pass("constant_propagation", foo.graph)
615*da0073e9SAndroid Build Coastguard Worker        self.run_pass("dce", foo.graph)
616*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("graph").check_next("return").run(foo.graph)
617*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(2), 2)
618*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(1), 1)
619*da0073e9SAndroid Build Coastguard Worker
620*da0073e9SAndroid Build Coastguard Worker    def test_peephole_len_list(self):
621*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
622*da0073e9SAndroid Build Coastguard Worker        def foo(x):
623*da0073e9SAndroid Build Coastguard Worker            return len(x.size())
624*da0073e9SAndroid Build Coastguard Worker
625*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
626*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::len").run(foo.graph)
627*da0073e9SAndroid Build Coastguard Worker        inputs = list(foo.graph.inputs())
628*da0073e9SAndroid Build Coastguard Worker        inputs[0].setType(inputs[0].type().with_sizes([None, None]))
629*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
630*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::len").run(foo.graph)
631*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, foo(torch.rand([3, 1])))
632*da0073e9SAndroid Build Coastguard Worker
633*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
634*da0073e9SAndroid Build Coastguard Worker        def foo(x):
635*da0073e9SAndroid Build Coastguard Worker            li = x.size()
636*da0073e9SAndroid Build Coastguard Worker            li.append(4)
637*da0073e9SAndroid Build Coastguard Worker            return len(li)
638*da0073e9SAndroid Build Coastguard Worker
639*da0073e9SAndroid Build Coastguard Worker        inputs = list(foo.graph.inputs())
640*da0073e9SAndroid Build Coastguard Worker        inputs[0].setType(inputs[0].type().with_sizes([None, None]))
641*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
642*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::len").run(foo.graph)
643*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(3, foo(torch.rand([3, 1])))
644*da0073e9SAndroid Build Coastguard Worker
645*da0073e9SAndroid Build Coastguard Worker    def test_peephole_optional_refine(self):
646*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
647*da0073e9SAndroid Build Coastguard Worker        def foo(z: int, z2: int, cond: bool):
648*da0073e9SAndroid Build Coastguard Worker            if cond:
649*da0073e9SAndroid Build Coastguard Worker                return z
650*da0073e9SAndroid Build Coastguard Worker            else:
651*da0073e9SAndroid Build Coastguard Worker                return z2
652*da0073e9SAndroid Build Coastguard Worker
653*da0073e9SAndroid Build Coastguard Worker        out = next(foo.graph.findNode("prim::If").outputs())
654*da0073e9SAndroid Build Coastguard Worker        out.setType(torch._C.OptionalType(torch._C.IntType.get()))
655*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
656*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("int?").run(foo.graph)
657*da0073e9SAndroid Build Coastguard Worker
658*da0073e9SAndroid Build Coastguard Worker    def test_peephole_int(self):
659*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
660*da0073e9SAndroid Build Coastguard Worker        def foo(x):
661*da0073e9SAndroid Build Coastguard Worker            # type: (number)
662*da0073e9SAndroid Build Coastguard Worker            return int(x)
663*da0073e9SAndroid Build Coastguard Worker
664*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::Int").run(foo.graph)
665*da0073e9SAndroid Build Coastguard Worker        next(foo.graph.inputs()).setType(torch._C.IntType.get())
666*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
667*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::Int").run(foo.graph)
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker    def test_peephole_arith(self):
670*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
671*da0073e9SAndroid Build Coastguard Worker        def foo(input0: int, input1: int, input2: int, input3: int):
672*da0073e9SAndroid Build Coastguard Worker            _1 = torch.add(input1, 2)
673*da0073e9SAndroid Build Coastguard Worker            _3 = torch.add(input3, 2)
674*da0073e9SAndroid Build Coastguard Worker            _5 = torch.add(1, torch.sub(_1, 3) // 1)
675*da0073e9SAndroid Build Coastguard Worker            _6 = torch.add(1 * torch.sub(_3, 3) // 1, 1) / 1
676*da0073e9SAndroid Build Coastguard Worker            return [_5, int(_6)]
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::add").check("aten::sub").check("aten::mul").check(
679*da0073e9SAndroid Build Coastguard Worker            "aten::floordiv"
680*da0073e9SAndroid Build Coastguard Worker        ).check("aten::div").run(foo.graph)
681*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
682*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("graph").check("):").check_next("ListConstruct").check_next(
683*da0073e9SAndroid Build Coastguard Worker            "return"
684*da0073e9SAndroid Build Coastguard Worker        ).run(foo.graph)
685*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(0, 1, 2, 3), [1, 3])
686*da0073e9SAndroid Build Coastguard Worker
687*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_getitem_simple(self):
688*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
689*da0073e9SAndroid Build Coastguard Worker        def foo(a: int, b: int):
690*da0073e9SAndroid Build Coastguard Worker            d = {0: a, 1: b}
691*da0073e9SAndroid Build Coastguard Worker            x = d[1]
692*da0073e9SAndroid Build Coastguard Worker            y = d[0]
693*da0073e9SAndroid Build Coastguard Worker            return x, y
694*da0073e9SAndroid Build Coastguard Worker
695*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
696*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph)
697*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(0, 1), (1, 0))
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
700*da0073e9SAndroid Build Coastguard Worker        def foo(a: int, b: int):
701*da0073e9SAndroid Build Coastguard Worker            d = {"0": a, "1": b}
702*da0073e9SAndroid Build Coastguard Worker            x = d["1"]
703*da0073e9SAndroid Build Coastguard Worker            y = d["0"]
704*da0073e9SAndroid Build Coastguard Worker            return x, y
705*da0073e9SAndroid Build Coastguard Worker
706*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
707*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph)
708*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(0, 1), (1, 0))
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
711*da0073e9SAndroid Build Coastguard Worker        def foo(a: int, b: int):
712*da0073e9SAndroid Build Coastguard Worker            d = {0.0: a, 1.0: b}
713*da0073e9SAndroid Build Coastguard Worker            x = d[1.0]
714*da0073e9SAndroid Build Coastguard Worker            y = d[0.0]
715*da0073e9SAndroid Build Coastguard Worker            return x, y
716*da0073e9SAndroid Build Coastguard Worker
717*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
718*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph)
719*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(0, 1), (1, 0))
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_getitem_no_optimization_missing_key(self):
722*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
723*da0073e9SAndroid Build Coastguard Worker        def foo():
724*da0073e9SAndroid Build Coastguard Worker            d = {0: 1}
725*da0073e9SAndroid Build Coastguard Worker            return d[2]
726*da0073e9SAndroid Build Coastguard Worker
727*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
728*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_getitem_no_optimization_get_input_arg(self):
731*da0073e9SAndroid Build Coastguard Worker        # Here we don't know if the input arg is in the dict, so we can't
732*da0073e9SAndroid Build Coastguard Worker        # make the optimization.
733*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
734*da0073e9SAndroid Build Coastguard Worker        def foo(a: int):
735*da0073e9SAndroid Build Coastguard Worker            d = {0: 1}
736*da0073e9SAndroid Build Coastguard Worker            return d[a]
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
739*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
740*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(0), 1)
741*da0073e9SAndroid Build Coastguard Worker
742*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_getitem_no_optimization_dict_modified(self):
743*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
744*da0073e9SAndroid Build Coastguard Worker        def foo():
745*da0073e9SAndroid Build Coastguard Worker            d = {0: 1}
746*da0073e9SAndroid Build Coastguard Worker            d[0] = 2
747*da0073e9SAndroid Build Coastguard Worker            return d[0]
748*da0073e9SAndroid Build Coastguard Worker
749*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
750*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
751*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), 2)
752*da0073e9SAndroid Build Coastguard Worker
753*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_getitem_no_optimization_overlapping_keys(self):
754*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
755*da0073e9SAndroid Build Coastguard Worker        def foo():
756*da0073e9SAndroid Build Coastguard Worker            d = {0: 1, 0: 2}  # noqa: F601
757*da0073e9SAndroid Build Coastguard Worker            return d[0]
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
760*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
761*da0073e9SAndroid Build Coastguard Worker
762*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_getitem_no_optimization_keys_might_overlap(self):
763*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
764*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
765*da0073e9SAndroid Build Coastguard Worker            d = {0: 1, x: 2}
766*da0073e9SAndroid Build Coastguard Worker            return d[x]
767*da0073e9SAndroid Build Coastguard Worker
768*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
769*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
770*da0073e9SAndroid Build Coastguard Worker
771*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_getitem_no_optimization_unsupported_type(self):
772*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
773*da0073e9SAndroid Build Coastguard Worker        def foo():
774*da0073e9SAndroid Build Coastguard Worker            a = torch.rand((2, 2))
775*da0073e9SAndroid Build Coastguard Worker            d = {a: 1}
776*da0073e9SAndroid Build Coastguard Worker            return d[a]
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
779*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph)
780*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), 1)
781*da0073e9SAndroid Build Coastguard Worker
782*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_len(self):
783*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
784*da0073e9SAndroid Build Coastguard Worker        def foo():
785*da0073e9SAndroid Build Coastguard Worker            d = {0: 1, 1: 2}
786*da0073e9SAndroid Build Coastguard Worker            return len(d)
787*da0073e9SAndroid Build Coastguard Worker
788*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
789*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("DictConstruct").check_not("len").run(foo.graph)
790*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), 2)
791*da0073e9SAndroid Build Coastguard Worker
792*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_len_no_optimization_overlapping_keys(self):
793*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
794*da0073e9SAndroid Build Coastguard Worker        def foo():
795*da0073e9SAndroid Build Coastguard Worker            d = {0: 1, 0: 2}  # noqa: F601
796*da0073e9SAndroid Build Coastguard Worker            return len(d)
797*da0073e9SAndroid Build Coastguard Worker
798*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
799*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DictConstruct").check("len").run(foo.graph)
800*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), 1)
801*da0073e9SAndroid Build Coastguard Worker
802*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_len_no_optimization_keys_might_overlap(self):
803*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
804*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
805*da0073e9SAndroid Build Coastguard Worker            d = {0: 1, x: 2}
806*da0073e9SAndroid Build Coastguard Worker            return len(d)
807*da0073e9SAndroid Build Coastguard Worker
808*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
809*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DictConstruct").check("len").run(foo.graph)
810*da0073e9SAndroid Build Coastguard Worker
811*da0073e9SAndroid Build Coastguard Worker    def test_peephole_dict_len_no_optimization_unsupported_type(self):
812*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
813*da0073e9SAndroid Build Coastguard Worker        def foo():
814*da0073e9SAndroid Build Coastguard Worker            a = torch.rand((2, 2))
815*da0073e9SAndroid Build Coastguard Worker            d = {a: 1}
816*da0073e9SAndroid Build Coastguard Worker            return len(d)
817*da0073e9SAndroid Build Coastguard Worker
818*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
819*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("DictConstruct").check("len").run(foo.graph)
820*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo(), 1)
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker    def test_peephole_slice_all_three_args(self):
823*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
824*da0073e9SAndroid Build Coastguard Worker            return [1, 2, x, 4, 5, 6, 7][-5:6:2]
825*da0073e9SAndroid Build Coastguard Worker
826*da0073e9SAndroid Build Coastguard Worker        graph = torch.jit.script(foo).graph
827*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", graph)
828*da0073e9SAndroid Build Coastguard Worker        FileCheck().check_not("aten::slice").run(graph)
829*da0073e9SAndroid Build Coastguard Worker        self.checkScript(foo, (3,))
830*da0073e9SAndroid Build Coastguard Worker
831*da0073e9SAndroid Build Coastguard Worker    def test_peephole_slice_one_empty_arg(self):
832*da0073e9SAndroid Build Coastguard Worker        def check_helper(fn: Callable[[int], None]) -> None:
833*da0073e9SAndroid Build Coastguard Worker            graph = torch.jit.script(fn).graph
834*da0073e9SAndroid Build Coastguard Worker            self.run_pass("peephole", graph)
835*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("aten::slice").run(graph)
836*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn, (3,))
837*da0073e9SAndroid Build Coastguard Worker
838*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
839*da0073e9SAndroid Build Coastguard Worker            return [1, 2, x, 4, 5, 6, 7][1::2]
840*da0073e9SAndroid Build Coastguard Worker
841*da0073e9SAndroid Build Coastguard Worker        check_helper(foo)
842*da0073e9SAndroid Build Coastguard Worker
843*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
844*da0073e9SAndroid Build Coastguard Worker            return [1, 2, x, 4, 5, 6, 7][:5:3]
845*da0073e9SAndroid Build Coastguard Worker
846*da0073e9SAndroid Build Coastguard Worker        check_helper(foo)
847*da0073e9SAndroid Build Coastguard Worker
848*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
849*da0073e9SAndroid Build Coastguard Worker            return [1, 2, x, 4, 5, 6, 7][0:4]
850*da0073e9SAndroid Build Coastguard Worker
851*da0073e9SAndroid Build Coastguard Worker        check_helper(foo)
852*da0073e9SAndroid Build Coastguard Worker
853*da0073e9SAndroid Build Coastguard Worker    def test_peephole_slice_two_empty_args(self):
854*da0073e9SAndroid Build Coastguard Worker        def check_helper(fn: Callable[[int], None]) -> None:
855*da0073e9SAndroid Build Coastguard Worker            graph = torch.jit.script(fn).graph
856*da0073e9SAndroid Build Coastguard Worker            self.run_pass("peephole", graph)
857*da0073e9SAndroid Build Coastguard Worker            FileCheck().check_not("aten::slice").run(graph)
858*da0073e9SAndroid Build Coastguard Worker            self.checkScript(fn, (3,))
859*da0073e9SAndroid Build Coastguard Worker
860*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
861*da0073e9SAndroid Build Coastguard Worker            return [1, 2, x, 4, 5, 6, 7][::2]
862*da0073e9SAndroid Build Coastguard Worker
863*da0073e9SAndroid Build Coastguard Worker        check_helper(foo)
864*da0073e9SAndroid Build Coastguard Worker
865*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
866*da0073e9SAndroid Build Coastguard Worker            return [1, 2, x, 4, 5, 6, 7][:5]
867*da0073e9SAndroid Build Coastguard Worker
868*da0073e9SAndroid Build Coastguard Worker        check_helper(foo)
869*da0073e9SAndroid Build Coastguard Worker
870*da0073e9SAndroid Build Coastguard Worker        def foo(x: int):
871*da0073e9SAndroid Build Coastguard Worker            return [1, 2, x, 4, 5, 6, 7][1:]
872*da0073e9SAndroid Build Coastguard Worker
873*da0073e9SAndroid Build Coastguard Worker        check_helper(foo)
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker    def test_peephole_slice_optimization_not_applied_list_modified(self):
876*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
877*da0073e9SAndroid Build Coastguard Worker        def foo():
878*da0073e9SAndroid Build Coastguard Worker            li = [1, 2, 3, 4, 5, 6, 7]
879*da0073e9SAndroid Build Coastguard Worker            li[0] = 0
880*da0073e9SAndroid Build Coastguard Worker            return li[2:5]
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
883*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::slice").run(foo.graph)
884*da0073e9SAndroid Build Coastguard Worker
885*da0073e9SAndroid Build Coastguard Worker    def test_peephole_slice_optimization_not_applied_non_const_args(self):
886*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
887*da0073e9SAndroid Build Coastguard Worker        def foo(x: int, y: int):
888*da0073e9SAndroid Build Coastguard Worker            li = [1, 2, 3, 4, 5, 6, 7]
889*da0073e9SAndroid Build Coastguard Worker            return li[x:y]
890*da0073e9SAndroid Build Coastguard Worker
891*da0073e9SAndroid Build Coastguard Worker        self.run_pass("peephole", foo.graph)
892*da0073e9SAndroid Build Coastguard Worker        FileCheck().check("aten::slice").run(foo.graph)
893