1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport unittest 4*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, List 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerfrom torch import nn 8*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck 9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 13*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 14*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 15*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 16*da0073e9SAndroid Build Coastguard Worker "instead." 17*da0073e9SAndroid Build Coastguard Worker ) 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Workerclass TestPeephole(JitTestCase): 21*da0073e9SAndroid Build Coastguard Worker def test_peephole_with_writes(self): 22*da0073e9SAndroid Build Coastguard Worker def test_write(x): 23*da0073e9SAndroid Build Coastguard Worker s = 0 24*da0073e9SAndroid Build Coastguard Worker s += x 25*da0073e9SAndroid Build Coastguard Worker s += x 26*da0073e9SAndroid Build Coastguard Worker return s 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_write, (torch.ones(4, 4),)) 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker def test_peephole_with_non_output_writes(self): 31*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 32*da0073e9SAndroid Build Coastguard Worker def nomnom(x): 33*da0073e9SAndroid Build Coastguard Worker pass 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker def test_write(x): 36*da0073e9SAndroid Build Coastguard Worker t = torch.ones_like(x) 37*da0073e9SAndroid Build Coastguard Worker z = x.clone() 38*da0073e9SAndroid Build Coastguard Worker y = z + 0 39*da0073e9SAndroid Build Coastguard Worker z.add_(t) 40*da0073e9SAndroid Build Coastguard Worker # this makes sure z isn't blasted out of existence 41*da0073e9SAndroid Build Coastguard Worker # because it isn't returned or used in a side-effectful 42*da0073e9SAndroid Build Coastguard Worker # way 43*da0073e9SAndroid Build Coastguard Worker nomnom(z) 44*da0073e9SAndroid Build Coastguard Worker return y + y 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker a = torch.ones(4, 4) 47*da0073e9SAndroid Build Coastguard Worker j = self.checkScript(test_write, (a,)) 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker def test_peephole_no_output_aliasing(self): 50*da0073e9SAndroid Build Coastguard Worker def test_peephole(x): 51*da0073e9SAndroid Build Coastguard Worker y = x + 0 52*da0073e9SAndroid Build Coastguard Worker return x, y 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker a = torch.ones(4, 4) 55*da0073e9SAndroid Build Coastguard Worker j = self.checkScript(test_peephole, (a,)) 56*da0073e9SAndroid Build Coastguard Worker r1, r2 = j(a) 57*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(r1.data_ptr(), r2.data_ptr()) 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker def test_peephole(self): 60*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([0.4]) 61*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([0.7]) 62*da0073e9SAndroid Build Coastguard Worker c = torch.tensor([0], dtype=torch.int32) 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker def f(x, y): 65*da0073e9SAndroid Build Coastguard Worker return x.type_as(y) 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker tf = torch.jit.trace(f, (a, b)) 68*da0073e9SAndroid Build Coastguard Worker FileCheck().check("type_as").run(str(tf.graph)) 69*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", tf.graph) 70*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("type_as").run(str(tf.graph)) 71*da0073e9SAndroid Build Coastguard Worker tf2 = torch.jit.trace(f, (a, c)) 72*da0073e9SAndroid Build Coastguard Worker s = str(tf2.graph) 73*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", tf2.graph) 74*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s, str(s)) 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker def test_peephole_dynamic(self): 77*da0073e9SAndroid Build Coastguard Worker def f(x, y): 78*da0073e9SAndroid Build Coastguard Worker return x.type_as(y) 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker fn = torch.jit.script(f) 81*da0073e9SAndroid Build Coastguard Worker s = str(fn.graph) 82*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole(fn.graph) 83*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s, str(fn.graph)) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker def test_peephole_list_ops(self): 86*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 87*da0073e9SAndroid Build Coastguard Worker def foo(x, y, z): 88*da0073e9SAndroid Build Coastguard Worker return len([x, y, z]) 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 91*da0073e9SAndroid Build Coastguard Worker FileCheck().check("value=3").check_next("return").run(foo.graph) 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 94*da0073e9SAndroid Build Coastguard Worker def foo(x, y, z): 95*da0073e9SAndroid Build Coastguard Worker li = [x, y, z] 96*da0073e9SAndroid Build Coastguard Worker for i in range(len(x)): 97*da0073e9SAndroid Build Coastguard Worker li.append(x) 98*da0073e9SAndroid Build Coastguard Worker return len([x, y, z]) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 101*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::len").run(foo.graph) 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 104*da0073e9SAndroid Build Coastguard Worker def foo(x, y, z): 105*da0073e9SAndroid Build Coastguard Worker li = [x, y, z] 106*da0073e9SAndroid Build Coastguard Worker return li[1], li[-2] 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::__getitem__").run(foo.graph) 109*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 110*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::__getitem__").run(foo.graph) 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 113*da0073e9SAndroid Build Coastguard Worker def foo(x, y, z): 114*da0073e9SAndroid Build Coastguard Worker li = [x, y, z] 115*da0073e9SAndroid Build Coastguard Worker return li[-7] 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 118*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::__getitem__").run(foo.graph) 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 121*da0073e9SAndroid Build Coastguard Worker def foo(x, y, z): 122*da0073e9SAndroid Build Coastguard Worker li = [x, y, z] 123*da0073e9SAndroid Build Coastguard Worker for i in range(len(x)): 124*da0073e9SAndroid Build Coastguard Worker li.append(x) 125*da0073e9SAndroid Build Coastguard Worker return li[-2] 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 128*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::__getitem__").run(foo.graph) 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA") 131*da0073e9SAndroid Build Coastguard Worker def test_peephole_cuda(self): 132*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([0.4], device="cpu") 133*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([0.7], device="cuda") 134*da0073e9SAndroid Build Coastguard Worker c = torch.tensor([0.7], device="cuda") 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker def f(x, y): 137*da0073e9SAndroid Build Coastguard Worker return x.type_as(y) 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker trace = torch.jit.trace(f, (a, c)) 140*da0073e9SAndroid Build Coastguard Worker s = str(trace.graph) 141*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", trace.graph) 142*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s, str(trace.graph)) 143*da0073e9SAndroid Build Coastguard Worker trace = torch.jit.trace(f, (b, c)) 144*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", trace.graph) 145*da0073e9SAndroid Build Coastguard Worker self.run_pass("dce", trace.graph) 146*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("type_as").run(str(trace.graph)) 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker @_inline_everything 149*da0073e9SAndroid Build Coastguard Worker def test_peephole_type_refinements(self): 150*da0073e9SAndroid Build Coastguard Worker def refine(x): 151*da0073e9SAndroid Build Coastguard Worker # type: (Optional[Tensor]) -> Tensor 152*da0073e9SAndroid Build Coastguard Worker return x if x is not None else torch.tensor(3) 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 155*da0073e9SAndroid Build Coastguard Worker def test(): 156*da0073e9SAndroid Build Coastguard Worker return refine(torch.tensor(4)) 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::unchecked_cast").run(test.graph) 159*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", test.graph) 160*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::unchecked_cast").run(test.graph) 161*da0073e9SAndroid Build Coastguard Worker 162*da0073e9SAndroid Build Coastguard Worker # refinement not optimzied out 163*da0073e9SAndroid Build Coastguard Worker def is_int_tensor(x): 164*da0073e9SAndroid Build Coastguard Worker scalar = x.item() 165*da0073e9SAndroid Build Coastguard Worker if isinstance(scalar, int): 166*da0073e9SAndroid Build Coastguard Worker return scalar + 3 167*da0073e9SAndroid Build Coastguard Worker else: 168*da0073e9SAndroid Build Coastguard Worker return 8 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker self.checkScript(is_int_tensor, (torch.tensor(2),)) 171*da0073e9SAndroid Build Coastguard Worker self.checkScript(is_int_tensor, (torch.tensor(2.5),)) 172*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(is_int_tensor).graph 173*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", graph) 174*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::unchecked_cast").run(graph) 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker def test_short_circuit_optimization(self): 177*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 178*da0073e9SAndroid Build Coastguard Worker def const_expressions(x): 179*da0073e9SAndroid Build Coastguard Worker # type: (int) -> Tuple[bool, bool] 180*da0073e9SAndroid Build Coastguard Worker return x == 1 and False, x == 1 or True 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", const_expressions.graph) 183*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::If").check_not("aten::eq").run( 184*da0073e9SAndroid Build Coastguard Worker const_expressions.graph 185*da0073e9SAndroid Build Coastguard Worker ) 186*da0073e9SAndroid Build Coastguard Worker self.assertEqual(const_expressions(1), (False, True)) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 189*da0073e9SAndroid Build Coastguard Worker def redundant_expressions(x): 190*da0073e9SAndroid Build Coastguard Worker # type: (int) -> Tuple[bool, bool] 191*da0073e9SAndroid Build Coastguard Worker return x == 1 and True, x == 1 or False 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", redundant_expressions.graph) 194*da0073e9SAndroid Build Coastguard Worker self.assertEqual(redundant_expressions(1), (True, True)) 195*da0073e9SAndroid Build Coastguard Worker self.assertEqual(redundant_expressions(0), (False, False)) 196*da0073e9SAndroid Build Coastguard Worker # and True / or False are removed from graph 197*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::eq").check_not("prim::If").run( 198*da0073e9SAndroid Build Coastguard Worker redundant_expressions.graph 199*da0073e9SAndroid Build Coastguard Worker ) 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker def test_conv_dim_folding(self): 202*da0073e9SAndroid Build Coastguard Worker modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d] 203*da0073e9SAndroid Build Coastguard Worker for mod in modules: 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker class ConvDim(torch.nn.Module): 206*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 207*da0073e9SAndroid Build Coastguard Worker super().__init__() 208*da0073e9SAndroid Build Coastguard Worker self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False) 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 211*da0073e9SAndroid Build Coastguard Worker x = self.conv(x) 212*da0073e9SAndroid Build Coastguard Worker return x.dim() 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker conv_dim = torch.jit.script(ConvDim()) 215*da0073e9SAndroid Build Coastguard Worker self.run_pass("inline", conv_dim.graph) 216*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", conv_dim.graph) 217*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("conv").check_not("dim").run(conv_dim.graph) 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker class ConvDimMutate(torch.nn.Module): 220*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 221*da0073e9SAndroid Build Coastguard Worker super().__init__() 222*da0073e9SAndroid Build Coastguard Worker self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 225*da0073e9SAndroid Build Coastguard Worker x = self.conv(x) 226*da0073e9SAndroid Build Coastguard Worker x.resize_([4, 4]) 227*da0073e9SAndroid Build Coastguard Worker return x.dim() 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker conv_dim = torch.jit.script(ConvDimMutate()) 230*da0073e9SAndroid Build Coastguard Worker self.run_pass("inline", conv_dim.graph) 231*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", conv_dim.graph) 232*da0073e9SAndroid Build Coastguard Worker FileCheck().check("conv").check("dim").run(conv_dim.graph) 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker def test_normalized_rsub(self): 235*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1, 2, 3]) 236*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([4, 5, 6]) 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker def convertible_rsub(x, y): 239*da0073e9SAndroid Build Coastguard Worker return (x - y), torch.rsub(y, x) 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Worker self.checkScript(convertible_rsub, (a, b)) 242*da0073e9SAndroid Build Coastguard Worker op_graph = torch.jit.script(convertible_rsub).graph 243*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::sub", 2, exactly=True).run(op_graph) 244*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::rsub", 0, exactly=True).run(op_graph) 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker def test_normalized_is_op(self): 247*da0073e9SAndroid Build Coastguard Worker def convertible_is_op(x: bool, y: bool): 248*da0073e9SAndroid Build Coastguard Worker return x is True, False is x, x is y 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker self.checkScript(convertible_is_op, (True, False)) 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker op_graph = torch.jit.script(convertible_is_op).graph 253*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::eq", 3, exactly=True).run(op_graph) 254*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::__is__", 0, exactly=True).run(op_graph) 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker def test_normalized_isnot_op(self): 257*da0073e9SAndroid Build Coastguard Worker def convertible_isnot_op(x: bool, y: bool): 258*da0073e9SAndroid Build Coastguard Worker return x is not True, False is not x, x is not y 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker self.checkScript(convertible_isnot_op, (True, False)) 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker op_graph = torch.jit.script(convertible_isnot_op).graph 263*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::ne", 3, exactly=True).run(op_graph) 264*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::__isnot__", 0, exactly=True).run(op_graph) 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker def test_peephole_list_len(self): 267*da0073e9SAndroid Build Coastguard Worker def run_peephole_and_check_const_value(graph, const_string): 268*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole_list_idioms(graph, refine_list_len=True) 269*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", graph) 270*da0073e9SAndroid Build Coastguard Worker FileCheck().check(const_string).check_next("return").run(graph) 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker def gen_li(inp_len: int): 273*da0073e9SAndroid Build Coastguard Worker return [0 for i in range(inp_len)] 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 276*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], y: List[int]): 277*da0073e9SAndroid Build Coastguard Worker if len(x) != 4 or len(y) != 5: 278*da0073e9SAndroid Build Coastguard Worker raise Exception("") # noqa: TRY002 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker return len(x) + len(y) 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker run_peephole_and_check_const_value(foo.graph, "value=9") 283*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(gen_li(4), gen_li(5)), 9) 284*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 285*da0073e9SAndroid Build Coastguard Worker foo(2, 4) 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 288*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], y: List[int]): 289*da0073e9SAndroid Build Coastguard Worker if len(x) == 4 and len(y) == 5: 290*da0073e9SAndroid Build Coastguard Worker pass 291*da0073e9SAndroid Build Coastguard Worker else: 292*da0073e9SAndroid Build Coastguard Worker raise Exception("hi") # noqa: TRY002 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker return len(x) + len(y) 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker run_peephole_and_check_const_value(foo.graph, "value=9") 297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(gen_li(4), gen_li(5)), 9) 298*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 299*da0073e9SAndroid Build Coastguard Worker foo(2, 4) 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 302*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], y: List[int], z: List[int]): 303*da0073e9SAndroid Build Coastguard Worker if len(x) != 4: 304*da0073e9SAndroid Build Coastguard Worker raise Exception("..") # noqa: TRY002 305*da0073e9SAndroid Build Coastguard Worker else: 306*da0073e9SAndroid Build Coastguard Worker if len(y) != 8: 307*da0073e9SAndroid Build Coastguard Worker raise Exception("...") # noqa: TRY002 308*da0073e9SAndroid Build Coastguard Worker else: 309*da0073e9SAndroid Build Coastguard Worker if len(z) == 3: 310*da0073e9SAndroid Build Coastguard Worker pass 311*da0073e9SAndroid Build Coastguard Worker else: 312*da0073e9SAndroid Build Coastguard Worker raise Exception("...") # noqa: TRY002 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker return len(x) + len(y) * len(z) 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker run_peephole_and_check_const_value(foo.graph, "value=28") 317*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(gen_li(4), gen_li(8), gen_li(3)), 28) 318*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 319*da0073e9SAndroid Build Coastguard Worker foo(1, 2, 3) 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker # refinement should persist in second len(x) call 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 324*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], cond: bool): 325*da0073e9SAndroid Build Coastguard Worker if len(x) == 4: 326*da0073e9SAndroid Build Coastguard Worker if cond: 327*da0073e9SAndroid Build Coastguard Worker return len(x) 328*da0073e9SAndroid Build Coastguard Worker return 4 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker return 4 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Worker run_peephole_and_check_const_value(foo.graph, "value=4") 333*da0073e9SAndroid Build Coastguard Worker 334*da0073e9SAndroid Build Coastguard Worker def test_const_tuple_output(graph, const_inputs): 335*da0073e9SAndroid Build Coastguard Worker tup = graph.findNode("prim::TupleConstruct") 336*da0073e9SAndroid Build Coastguard Worker for i, elem in enumerate(tup.inputs()): 337*da0073e9SAndroid Build Coastguard Worker if i in const_inputs: 338*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(elem.toIValue()) 339*da0073e9SAndroid Build Coastguard Worker else: 340*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(elem.toIValue()) 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker # testing combinations of x1 : {True, False} x 343*da0073e9SAndroid Build Coastguard Worker # {then/else branch} x assert {True/False} 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 346*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], b: List[int]): 347*da0073e9SAndroid Build Coastguard Worker if len(x) == 5: 348*da0073e9SAndroid Build Coastguard Worker x1 = True 349*da0073e9SAndroid Build Coastguard Worker else: 350*da0073e9SAndroid Build Coastguard Worker x1 = len(b) != 4 351*da0073e9SAndroid Build Coastguard Worker assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq 352*da0073e9SAndroid Build Coastguard Worker return len(x), len(b) 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 355*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_constant_propagation(foo.graph) 356*da0073e9SAndroid Build Coastguard Worker # we can only infer len(b) == 4 here 357*da0073e9SAndroid Build Coastguard Worker test_const_tuple_output(foo.graph, [1]) 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 360*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], b: List[int]): 361*da0073e9SAndroid Build Coastguard Worker if len(x) == 5: 362*da0073e9SAndroid Build Coastguard Worker x1 = False 363*da0073e9SAndroid Build Coastguard Worker else: 364*da0073e9SAndroid Build Coastguard Worker x1 = len(b) != 4 365*da0073e9SAndroid Build Coastguard Worker assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq 366*da0073e9SAndroid Build Coastguard Worker return len(x), len(b) 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 369*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_constant_propagation(foo.graph) 370*da0073e9SAndroid Build Coastguard Worker # cant infer anything 371*da0073e9SAndroid Build Coastguard Worker test_const_tuple_output(foo.graph, []) 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 374*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], b: List[int]): 375*da0073e9SAndroid Build Coastguard Worker if len(x) == 5: 376*da0073e9SAndroid Build Coastguard Worker x1 = True 377*da0073e9SAndroid Build Coastguard Worker else: 378*da0073e9SAndroid Build Coastguard Worker x1 = len(b) == 4 379*da0073e9SAndroid Build Coastguard Worker assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq 380*da0073e9SAndroid Build Coastguard Worker return len(x), len(b) 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 383*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_constant_propagation(foo.graph) 384*da0073e9SAndroid Build Coastguard Worker # we cant infer anything, only len(b) != 4 385*da0073e9SAndroid Build Coastguard Worker test_const_tuple_output(foo.graph, []) 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 388*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], b: List[int]): 389*da0073e9SAndroid Build Coastguard Worker if len(x) == 5: 390*da0073e9SAndroid Build Coastguard Worker x1 = True 391*da0073e9SAndroid Build Coastguard Worker else: 392*da0073e9SAndroid Build Coastguard Worker x1 = len(b) != 4 393*da0073e9SAndroid Build Coastguard Worker assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq 394*da0073e9SAndroid Build Coastguard Worker return len(x), len(b) 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 397*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_constant_propagation(foo.graph) 398*da0073e9SAndroid Build Coastguard Worker # can infer len(b) == 4 399*da0073e9SAndroid Build Coastguard Worker test_const_tuple_output(foo.graph, [1]) 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker # swap branches 402*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 403*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], b: List[int]): 404*da0073e9SAndroid Build Coastguard Worker if len(x) != 5: 405*da0073e9SAndroid Build Coastguard Worker x1 = len(b) != 4 406*da0073e9SAndroid Build Coastguard Worker else: 407*da0073e9SAndroid Build Coastguard Worker x1 = True 408*da0073e9SAndroid Build Coastguard Worker assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq 409*da0073e9SAndroid Build Coastguard Worker return len(x), len(b) 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 412*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_constant_propagation(foo.graph) 413*da0073e9SAndroid Build Coastguard Worker # can infer len(b) == 4 414*da0073e9SAndroid Build Coastguard Worker test_const_tuple_output(foo.graph, [1]) 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker # use __not__ 417*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 418*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], b: List[int]): 419*da0073e9SAndroid Build Coastguard Worker if len(x) != 5: 420*da0073e9SAndroid Build Coastguard Worker x1 = len(b) != 4 421*da0073e9SAndroid Build Coastguard Worker else: 422*da0073e9SAndroid Build Coastguard Worker x1 = True 423*da0073e9SAndroid Build Coastguard Worker assert not x1 424*da0073e9SAndroid Build Coastguard Worker return len(x), len(b) 425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 427*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_constant_propagation(foo.graph) 428*da0073e9SAndroid Build Coastguard Worker # can infer len(b) == 4 429*da0073e9SAndroid Build Coastguard Worker test_const_tuple_output(foo.graph, [1]) 430*da0073e9SAndroid Build Coastguard Worker 431*da0073e9SAndroid Build Coastguard Worker # Test unsuccessful optimizations 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 434*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int]): 435*da0073e9SAndroid Build Coastguard Worker assert len(x) == 4 436*da0073e9SAndroid Build Coastguard Worker x.append(3) 437*da0073e9SAndroid Build Coastguard Worker return len(x) 438*da0073e9SAndroid Build Coastguard Worker 439*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 440*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", foo.graph) 441*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::len", 2).run(foo.graph) 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 444*da0073e9SAndroid Build Coastguard Worker def foo(x: List[int], y: List[int]): 445*da0073e9SAndroid Build Coastguard Worker assert len(x) == 4 or len(y) == 5 446*da0073e9SAndroid Build Coastguard Worker return len(x) + len(y) 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 449*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", foo.graph) 450*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::len", 4).run(foo.graph) 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Worker def test_integer_refinement(self): 453*da0073e9SAndroid Build Coastguard Worker def run_peephole_and_check_const_value(graph, const_string): 454*da0073e9SAndroid Build Coastguard Worker self.run_pass("refine_integer_values", graph) 455*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", graph) 456*da0073e9SAndroid Build Coastguard Worker self.run_pass("dce", graph) 457*da0073e9SAndroid Build Coastguard Worker FileCheck().check(const_string).check_next("return").run(graph) 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 460*da0073e9SAndroid Build Coastguard Worker def foo(x: int, y: int): 461*da0073e9SAndroid Build Coastguard Worker if x != 4 or y != 5: 462*da0073e9SAndroid Build Coastguard Worker raise Exception("") # noqa: TRY002 463*da0073e9SAndroid Build Coastguard Worker 464*da0073e9SAndroid Build Coastguard Worker return x + y 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Worker graph = foo.graph 467*da0073e9SAndroid Build Coastguard Worker self.run_pass("refine_integer_values", graph) 468*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", graph) 469*da0073e9SAndroid Build Coastguard Worker self.run_pass("dce", graph) 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker run_peephole_and_check_const_value(foo.graph, "value=9") 472*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(4, 5), 9) 473*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 474*da0073e9SAndroid Build Coastguard Worker foo(2, 4) 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 477*da0073e9SAndroid Build Coastguard Worker def foo(x: int, y: int): 478*da0073e9SAndroid Build Coastguard Worker if x == 4 and y == 5: 479*da0073e9SAndroid Build Coastguard Worker pass 480*da0073e9SAndroid Build Coastguard Worker else: 481*da0073e9SAndroid Build Coastguard Worker raise Exception("hi") # noqa: TRY002 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Worker return x + y 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker run_peephole_and_check_const_value(foo.graph, "value=9") 486*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(4, 5), 9) 487*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 488*da0073e9SAndroid Build Coastguard Worker foo(2, 4) 489*da0073e9SAndroid Build Coastguard Worker 490*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 491*da0073e9SAndroid Build Coastguard Worker def foo(x: int, y: int, z: int): 492*da0073e9SAndroid Build Coastguard Worker if x != 4: 493*da0073e9SAndroid Build Coastguard Worker raise Exception("..") # noqa: TRY002 494*da0073e9SAndroid Build Coastguard Worker else: 495*da0073e9SAndroid Build Coastguard Worker if y != 8: 496*da0073e9SAndroid Build Coastguard Worker raise Exception("...") # noqa: TRY002 497*da0073e9SAndroid Build Coastguard Worker else: 498*da0073e9SAndroid Build Coastguard Worker if z == 3: 499*da0073e9SAndroid Build Coastguard Worker pass 500*da0073e9SAndroid Build Coastguard Worker else: 501*da0073e9SAndroid Build Coastguard Worker raise Exception("...") # noqa: TRY002 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker return x + y * z 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker run_peephole_and_check_const_value(foo.graph, "value=28") 506*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(4, 8, 3), 28) 507*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 508*da0073e9SAndroid Build Coastguard Worker foo(1, 2, 3) 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Worker # refinement should persist in second len(x) call 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 513*da0073e9SAndroid Build Coastguard Worker def foo(x: int, cond: bool): 514*da0073e9SAndroid Build Coastguard Worker if x == 4: 515*da0073e9SAndroid Build Coastguard Worker if cond: 516*da0073e9SAndroid Build Coastguard Worker return x 517*da0073e9SAndroid Build Coastguard Worker return 4 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker return 4 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Worker run_peephole_and_check_const_value(foo.graph, "value=4") 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 524*da0073e9SAndroid Build Coastguard Worker def foo(x: int, y: int): 525*da0073e9SAndroid Build Coastguard Worker assert x == 4 or y == 5 526*da0073e9SAndroid Build Coastguard Worker return x + y 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 529*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", foo.graph) 530*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::add").run(foo.graph) 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker def test_optimize_out_comparison_same_value(self): 533*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 534*da0073e9SAndroid Build Coastguard Worker return x == x, x != x 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker def foo2(x: List[int]): 537*da0073e9SAndroid Build Coastguard Worker return x == x, x != x 538*da0073e9SAndroid Build Coastguard Worker 539*da0073e9SAndroid Build Coastguard Worker for func, inp in zip([foo, foo2], [1, [2, 3]]): 540*da0073e9SAndroid Build Coastguard Worker func_s = torch.jit.script(func) 541*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", func_s.graph) 542*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::eq").check_not("aten::neq").run(func_s.graph) 543*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(inp), func_s(inp)) 544*da0073e9SAndroid Build Coastguard Worker 545*da0073e9SAndroid Build Coastguard Worker def test_peephole_add_zero(self): 546*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 547*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 548*da0073e9SAndroid Build Coastguard Worker return x + 0, 0 + x 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 551*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::add") 552*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(3), (3, 3)) 553*da0073e9SAndroid Build Coastguard Worker 554*da0073e9SAndroid Build Coastguard Worker def test_noop_peephole(self): 555*da0073e9SAndroid Build Coastguard Worker # test unsuccessful 556*da0073e9SAndroid Build Coastguard Worker def foo1(x): 557*da0073e9SAndroid Build Coastguard Worker return x + 0 558*da0073e9SAndroid Build Coastguard Worker 559*da0073e9SAndroid Build Coastguard Worker def foo2(): 560*da0073e9SAndroid Build Coastguard Worker x = torch.zeros([2, 2]) 561*da0073e9SAndroid Build Coastguard Worker x.sub_(3) 562*da0073e9SAndroid Build Coastguard Worker return x + 0 563*da0073e9SAndroid Build Coastguard Worker 564*da0073e9SAndroid Build Coastguard Worker def foo3(): 565*da0073e9SAndroid Build Coastguard Worker x = torch.zeros([2, 2]) 566*da0073e9SAndroid Build Coastguard Worker return x, x + 0 567*da0073e9SAndroid Build Coastguard Worker 568*da0073e9SAndroid Build Coastguard Worker def foo4(): 569*da0073e9SAndroid Build Coastguard Worker x = torch.zeros([2, 2]) 570*da0073e9SAndroid Build Coastguard Worker return x + 0.0 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker funcs = foo1, foo2, foo3, foo4 573*da0073e9SAndroid Build Coastguard Worker inps = (torch.ones([2]),), (), (), () 574*da0073e9SAndroid Build Coastguard Worker for func, inp in zip(funcs, inps): 575*da0073e9SAndroid Build Coastguard Worker foo_s = torch.jit.script(func) 576*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo_s.graph) 577*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::add", 1, exactly=True).run(foo_s.graph) 578*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(*inp), foo_s(*inp)) 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker # successful 581*da0073e9SAndroid Build Coastguard Worker def func(x): 582*da0073e9SAndroid Build Coastguard Worker return (x + 0) * 1 - 5 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Worker func_s = torch.jit.script(func) 585*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", func_s.graph) 586*da0073e9SAndroid Build Coastguard Worker # bail on modified value first 587*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::add").check("aten::mul").run(func_s.graph) 588*da0073e9SAndroid Build Coastguard Worker # second run it should succeed 589*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", func_s.graph) 590*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::add").check_not("aten::mul").run(func_s.graph) 591*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(torch.ones([2, 2])), func_s(torch.ones([2, 2]))) 592*da0073e9SAndroid Build Coastguard Worker 593*da0073e9SAndroid Build Coastguard Worker def func(x): 594*da0073e9SAndroid Build Coastguard Worker return (x + 0.0) - 5 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker func_s = torch.jit.script(func) 597*da0073e9SAndroid Build Coastguard Worker inp = next(func_s.graph.inputs()) 598*da0073e9SAndroid Build Coastguard Worker inp.setType(torch._C.TensorType.create_from_tensor(torch.rand([2, 2]))) 599*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole(func_s.graph, disable_shape_peepholes=True) 600*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::add").run(func_s.graph) 601*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_peephole(func_s.graph, disable_shape_peepholes=False) 602*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::add").run(func_s.graph) 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker def test_refine_integer_values(self): 605*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 606*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 607*da0073e9SAndroid Build Coastguard Worker y = 1 608*da0073e9SAndroid Build Coastguard Worker if x == 1: 609*da0073e9SAndroid Build Coastguard Worker return y 610*da0073e9SAndroid Build Coastguard Worker else: 611*da0073e9SAndroid Build Coastguard Worker return x 612*da0073e9SAndroid Build Coastguard Worker 613*da0073e9SAndroid Build Coastguard Worker self.run_pass("refine_integer_values", foo.graph) 614*da0073e9SAndroid Build Coastguard Worker self.run_pass("constant_propagation", foo.graph) 615*da0073e9SAndroid Build Coastguard Worker self.run_pass("dce", foo.graph) 616*da0073e9SAndroid Build Coastguard Worker FileCheck().check("graph").check_next("return").run(foo.graph) 617*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(2), 2) 618*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(1), 1) 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Worker def test_peephole_len_list(self): 621*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 622*da0073e9SAndroid Build Coastguard Worker def foo(x): 623*da0073e9SAndroid Build Coastguard Worker return len(x.size()) 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 626*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::len").run(foo.graph) 627*da0073e9SAndroid Build Coastguard Worker inputs = list(foo.graph.inputs()) 628*da0073e9SAndroid Build Coastguard Worker inputs[0].setType(inputs[0].type().with_sizes([None, None])) 629*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 630*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::len").run(foo.graph) 631*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, foo(torch.rand([3, 1]))) 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 634*da0073e9SAndroid Build Coastguard Worker def foo(x): 635*da0073e9SAndroid Build Coastguard Worker li = x.size() 636*da0073e9SAndroid Build Coastguard Worker li.append(4) 637*da0073e9SAndroid Build Coastguard Worker return len(li) 638*da0073e9SAndroid Build Coastguard Worker 639*da0073e9SAndroid Build Coastguard Worker inputs = list(foo.graph.inputs()) 640*da0073e9SAndroid Build Coastguard Worker inputs[0].setType(inputs[0].type().with_sizes([None, None])) 641*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 642*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::len").run(foo.graph) 643*da0073e9SAndroid Build Coastguard Worker self.assertEqual(3, foo(torch.rand([3, 1]))) 644*da0073e9SAndroid Build Coastguard Worker 645*da0073e9SAndroid Build Coastguard Worker def test_peephole_optional_refine(self): 646*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 647*da0073e9SAndroid Build Coastguard Worker def foo(z: int, z2: int, cond: bool): 648*da0073e9SAndroid Build Coastguard Worker if cond: 649*da0073e9SAndroid Build Coastguard Worker return z 650*da0073e9SAndroid Build Coastguard Worker else: 651*da0073e9SAndroid Build Coastguard Worker return z2 652*da0073e9SAndroid Build Coastguard Worker 653*da0073e9SAndroid Build Coastguard Worker out = next(foo.graph.findNode("prim::If").outputs()) 654*da0073e9SAndroid Build Coastguard Worker out.setType(torch._C.OptionalType(torch._C.IntType.get())) 655*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 656*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("int?").run(foo.graph) 657*da0073e9SAndroid Build Coastguard Worker 658*da0073e9SAndroid Build Coastguard Worker def test_peephole_int(self): 659*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 660*da0073e9SAndroid Build Coastguard Worker def foo(x): 661*da0073e9SAndroid Build Coastguard Worker # type: (number) 662*da0073e9SAndroid Build Coastguard Worker return int(x) 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::Int").run(foo.graph) 665*da0073e9SAndroid Build Coastguard Worker next(foo.graph.inputs()).setType(torch._C.IntType.get()) 666*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 667*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::Int").run(foo.graph) 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker def test_peephole_arith(self): 670*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 671*da0073e9SAndroid Build Coastguard Worker def foo(input0: int, input1: int, input2: int, input3: int): 672*da0073e9SAndroid Build Coastguard Worker _1 = torch.add(input1, 2) 673*da0073e9SAndroid Build Coastguard Worker _3 = torch.add(input3, 2) 674*da0073e9SAndroid Build Coastguard Worker _5 = torch.add(1, torch.sub(_1, 3) // 1) 675*da0073e9SAndroid Build Coastguard Worker _6 = torch.add(1 * torch.sub(_3, 3) // 1, 1) / 1 676*da0073e9SAndroid Build Coastguard Worker return [_5, int(_6)] 677*da0073e9SAndroid Build Coastguard Worker 678*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::add").check("aten::sub").check("aten::mul").check( 679*da0073e9SAndroid Build Coastguard Worker "aten::floordiv" 680*da0073e9SAndroid Build Coastguard Worker ).check("aten::div").run(foo.graph) 681*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 682*da0073e9SAndroid Build Coastguard Worker FileCheck().check("graph").check("):").check_next("ListConstruct").check_next( 683*da0073e9SAndroid Build Coastguard Worker "return" 684*da0073e9SAndroid Build Coastguard Worker ).run(foo.graph) 685*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(0, 1, 2, 3), [1, 3]) 686*da0073e9SAndroid Build Coastguard Worker 687*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_getitem_simple(self): 688*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 689*da0073e9SAndroid Build Coastguard Worker def foo(a: int, b: int): 690*da0073e9SAndroid Build Coastguard Worker d = {0: a, 1: b} 691*da0073e9SAndroid Build Coastguard Worker x = d[1] 692*da0073e9SAndroid Build Coastguard Worker y = d[0] 693*da0073e9SAndroid Build Coastguard Worker return x, y 694*da0073e9SAndroid Build Coastguard Worker 695*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 696*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph) 697*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(0, 1), (1, 0)) 698*da0073e9SAndroid Build Coastguard Worker 699*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 700*da0073e9SAndroid Build Coastguard Worker def foo(a: int, b: int): 701*da0073e9SAndroid Build Coastguard Worker d = {"0": a, "1": b} 702*da0073e9SAndroid Build Coastguard Worker x = d["1"] 703*da0073e9SAndroid Build Coastguard Worker y = d["0"] 704*da0073e9SAndroid Build Coastguard Worker return x, y 705*da0073e9SAndroid Build Coastguard Worker 706*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 707*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph) 708*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(0, 1), (1, 0)) 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 711*da0073e9SAndroid Build Coastguard Worker def foo(a: int, b: int): 712*da0073e9SAndroid Build Coastguard Worker d = {0.0: a, 1.0: b} 713*da0073e9SAndroid Build Coastguard Worker x = d[1.0] 714*da0073e9SAndroid Build Coastguard Worker y = d[0.0] 715*da0073e9SAndroid Build Coastguard Worker return x, y 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 718*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph) 719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(0, 1), (1, 0)) 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_getitem_no_optimization_missing_key(self): 722*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 723*da0073e9SAndroid Build Coastguard Worker def foo(): 724*da0073e9SAndroid Build Coastguard Worker d = {0: 1} 725*da0073e9SAndroid Build Coastguard Worker return d[2] 726*da0073e9SAndroid Build Coastguard Worker 727*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 728*da0073e9SAndroid Build Coastguard Worker FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 729*da0073e9SAndroid Build Coastguard Worker 730*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_getitem_no_optimization_get_input_arg(self): 731*da0073e9SAndroid Build Coastguard Worker # Here we don't know if the input arg is in the dict, so we can't 732*da0073e9SAndroid Build Coastguard Worker # make the optimization. 733*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 734*da0073e9SAndroid Build Coastguard Worker def foo(a: int): 735*da0073e9SAndroid Build Coastguard Worker d = {0: 1} 736*da0073e9SAndroid Build Coastguard Worker return d[a] 737*da0073e9SAndroid Build Coastguard Worker 738*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 739*da0073e9SAndroid Build Coastguard Worker FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 740*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(0), 1) 741*da0073e9SAndroid Build Coastguard Worker 742*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_getitem_no_optimization_dict_modified(self): 743*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 744*da0073e9SAndroid Build Coastguard Worker def foo(): 745*da0073e9SAndroid Build Coastguard Worker d = {0: 1} 746*da0073e9SAndroid Build Coastguard Worker d[0] = 2 747*da0073e9SAndroid Build Coastguard Worker return d[0] 748*da0073e9SAndroid Build Coastguard Worker 749*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 750*da0073e9SAndroid Build Coastguard Worker FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 751*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), 2) 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_getitem_no_optimization_overlapping_keys(self): 754*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 755*da0073e9SAndroid Build Coastguard Worker def foo(): 756*da0073e9SAndroid Build Coastguard Worker d = {0: 1, 0: 2} # noqa: F601 757*da0073e9SAndroid Build Coastguard Worker return d[0] 758*da0073e9SAndroid Build Coastguard Worker 759*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 760*da0073e9SAndroid Build Coastguard Worker FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_getitem_no_optimization_keys_might_overlap(self): 763*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 764*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 765*da0073e9SAndroid Build Coastguard Worker d = {0: 1, x: 2} 766*da0073e9SAndroid Build Coastguard Worker return d[x] 767*da0073e9SAndroid Build Coastguard Worker 768*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 769*da0073e9SAndroid Build Coastguard Worker FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 770*da0073e9SAndroid Build Coastguard Worker 771*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_getitem_no_optimization_unsupported_type(self): 772*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 773*da0073e9SAndroid Build Coastguard Worker def foo(): 774*da0073e9SAndroid Build Coastguard Worker a = torch.rand((2, 2)) 775*da0073e9SAndroid Build Coastguard Worker d = {a: 1} 776*da0073e9SAndroid Build Coastguard Worker return d[a] 777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 779*da0073e9SAndroid Build Coastguard Worker FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 780*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), 1) 781*da0073e9SAndroid Build Coastguard Worker 782*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_len(self): 783*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 784*da0073e9SAndroid Build Coastguard Worker def foo(): 785*da0073e9SAndroid Build Coastguard Worker d = {0: 1, 1: 2} 786*da0073e9SAndroid Build Coastguard Worker return len(d) 787*da0073e9SAndroid Build Coastguard Worker 788*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 789*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("DictConstruct").check_not("len").run(foo.graph) 790*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), 2) 791*da0073e9SAndroid Build Coastguard Worker 792*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_len_no_optimization_overlapping_keys(self): 793*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 794*da0073e9SAndroid Build Coastguard Worker def foo(): 795*da0073e9SAndroid Build Coastguard Worker d = {0: 1, 0: 2} # noqa: F601 796*da0073e9SAndroid Build Coastguard Worker return len(d) 797*da0073e9SAndroid Build Coastguard Worker 798*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 799*da0073e9SAndroid Build Coastguard Worker FileCheck().check("DictConstruct").check("len").run(foo.graph) 800*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), 1) 801*da0073e9SAndroid Build Coastguard Worker 802*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_len_no_optimization_keys_might_overlap(self): 803*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 804*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 805*da0073e9SAndroid Build Coastguard Worker d = {0: 1, x: 2} 806*da0073e9SAndroid Build Coastguard Worker return len(d) 807*da0073e9SAndroid Build Coastguard Worker 808*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 809*da0073e9SAndroid Build Coastguard Worker FileCheck().check("DictConstruct").check("len").run(foo.graph) 810*da0073e9SAndroid Build Coastguard Worker 811*da0073e9SAndroid Build Coastguard Worker def test_peephole_dict_len_no_optimization_unsupported_type(self): 812*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 813*da0073e9SAndroid Build Coastguard Worker def foo(): 814*da0073e9SAndroid Build Coastguard Worker a = torch.rand((2, 2)) 815*da0073e9SAndroid Build Coastguard Worker d = {a: 1} 816*da0073e9SAndroid Build Coastguard Worker return len(d) 817*da0073e9SAndroid Build Coastguard Worker 818*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 819*da0073e9SAndroid Build Coastguard Worker FileCheck().check("DictConstruct").check("len").run(foo.graph) 820*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), 1) 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker def test_peephole_slice_all_three_args(self): 823*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 824*da0073e9SAndroid Build Coastguard Worker return [1, 2, x, 4, 5, 6, 7][-5:6:2] 825*da0073e9SAndroid Build Coastguard Worker 826*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(foo).graph 827*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", graph) 828*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::slice").run(graph) 829*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (3,)) 830*da0073e9SAndroid Build Coastguard Worker 831*da0073e9SAndroid Build Coastguard Worker def test_peephole_slice_one_empty_arg(self): 832*da0073e9SAndroid Build Coastguard Worker def check_helper(fn: Callable[[int], None]) -> None: 833*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(fn).graph 834*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", graph) 835*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::slice").run(graph) 836*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (3,)) 837*da0073e9SAndroid Build Coastguard Worker 838*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 839*da0073e9SAndroid Build Coastguard Worker return [1, 2, x, 4, 5, 6, 7][1::2] 840*da0073e9SAndroid Build Coastguard Worker 841*da0073e9SAndroid Build Coastguard Worker check_helper(foo) 842*da0073e9SAndroid Build Coastguard Worker 843*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 844*da0073e9SAndroid Build Coastguard Worker return [1, 2, x, 4, 5, 6, 7][:5:3] 845*da0073e9SAndroid Build Coastguard Worker 846*da0073e9SAndroid Build Coastguard Worker check_helper(foo) 847*da0073e9SAndroid Build Coastguard Worker 848*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 849*da0073e9SAndroid Build Coastguard Worker return [1, 2, x, 4, 5, 6, 7][0:4] 850*da0073e9SAndroid Build Coastguard Worker 851*da0073e9SAndroid Build Coastguard Worker check_helper(foo) 852*da0073e9SAndroid Build Coastguard Worker 853*da0073e9SAndroid Build Coastguard Worker def test_peephole_slice_two_empty_args(self): 854*da0073e9SAndroid Build Coastguard Worker def check_helper(fn: Callable[[int], None]) -> None: 855*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(fn).graph 856*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", graph) 857*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::slice").run(graph) 858*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (3,)) 859*da0073e9SAndroid Build Coastguard Worker 860*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 861*da0073e9SAndroid Build Coastguard Worker return [1, 2, x, 4, 5, 6, 7][::2] 862*da0073e9SAndroid Build Coastguard Worker 863*da0073e9SAndroid Build Coastguard Worker check_helper(foo) 864*da0073e9SAndroid Build Coastguard Worker 865*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 866*da0073e9SAndroid Build Coastguard Worker return [1, 2, x, 4, 5, 6, 7][:5] 867*da0073e9SAndroid Build Coastguard Worker 868*da0073e9SAndroid Build Coastguard Worker check_helper(foo) 869*da0073e9SAndroid Build Coastguard Worker 870*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 871*da0073e9SAndroid Build Coastguard Worker return [1, 2, x, 4, 5, 6, 7][1:] 872*da0073e9SAndroid Build Coastguard Worker 873*da0073e9SAndroid Build Coastguard Worker check_helper(foo) 874*da0073e9SAndroid Build Coastguard Worker 875*da0073e9SAndroid Build Coastguard Worker def test_peephole_slice_optimization_not_applied_list_modified(self): 876*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 877*da0073e9SAndroid Build Coastguard Worker def foo(): 878*da0073e9SAndroid Build Coastguard Worker li = [1, 2, 3, 4, 5, 6, 7] 879*da0073e9SAndroid Build Coastguard Worker li[0] = 0 880*da0073e9SAndroid Build Coastguard Worker return li[2:5] 881*da0073e9SAndroid Build Coastguard Worker 882*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 883*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::slice").run(foo.graph) 884*da0073e9SAndroid Build Coastguard Worker 885*da0073e9SAndroid Build Coastguard Worker def test_peephole_slice_optimization_not_applied_non_const_args(self): 886*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 887*da0073e9SAndroid Build Coastguard Worker def foo(x: int, y: int): 888*da0073e9SAndroid Build Coastguard Worker li = [1, 2, 3, 4, 5, 6, 7] 889*da0073e9SAndroid Build Coastguard Worker return li[x:y] 890*da0073e9SAndroid Build Coastguard Worker 891*da0073e9SAndroid Build Coastguard Worker self.run_pass("peephole", foo.graph) 892*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::slice").run(foo.graph) 893