1# Owner(s): ["oncall: jit"] 2 3import unittest 4from typing import Callable, List 5 6import torch 7from torch import nn 8from torch.testing import FileCheck 9from torch.testing._internal.jit_utils import _inline_everything, JitTestCase, RUN_CUDA 10 11 12if __name__ == "__main__": 13 raise RuntimeError( 14 "This test file is not meant to be run directly, use:\n\n" 15 "\tpython test/test_jit.py TESTNAME\n\n" 16 "instead." 17 ) 18 19 20class TestPeephole(JitTestCase): 21 def test_peephole_with_writes(self): 22 def test_write(x): 23 s = 0 24 s += x 25 s += x 26 return s 27 28 self.checkScript(test_write, (torch.ones(4, 4),)) 29 30 def test_peephole_with_non_output_writes(self): 31 @torch.jit.ignore 32 def nomnom(x): 33 pass 34 35 def test_write(x): 36 t = torch.ones_like(x) 37 z = x.clone() 38 y = z + 0 39 z.add_(t) 40 # this makes sure z isn't blasted out of existence 41 # because it isn't returned or used in a side-effectful 42 # way 43 nomnom(z) 44 return y + y 45 46 a = torch.ones(4, 4) 47 j = self.checkScript(test_write, (a,)) 48 49 def test_peephole_no_output_aliasing(self): 50 def test_peephole(x): 51 y = x + 0 52 return x, y 53 54 a = torch.ones(4, 4) 55 j = self.checkScript(test_peephole, (a,)) 56 r1, r2 = j(a) 57 self.assertNotEqual(r1.data_ptr(), r2.data_ptr()) 58 59 def test_peephole(self): 60 a = torch.tensor([0.4]) 61 b = torch.tensor([0.7]) 62 c = torch.tensor([0], dtype=torch.int32) 63 64 def f(x, y): 65 return x.type_as(y) 66 67 tf = torch.jit.trace(f, (a, b)) 68 FileCheck().check("type_as").run(str(tf.graph)) 69 self.run_pass("peephole", tf.graph) 70 FileCheck().check_not("type_as").run(str(tf.graph)) 71 tf2 = torch.jit.trace(f, (a, c)) 72 s = str(tf2.graph) 73 self.run_pass("peephole", tf2.graph) 74 self.assertEqual(s, str(s)) 75 76 def test_peephole_dynamic(self): 77 def f(x, y): 78 return x.type_as(y) 79 80 fn = torch.jit.script(f) 81 s = str(fn.graph) 82 torch._C._jit_pass_peephole(fn.graph) 83 self.assertEqual(s, str(fn.graph)) 84 85 def test_peephole_list_ops(self): 86 @torch.jit.script 87 def foo(x, y, z): 88 return len([x, y, z]) 89 90 self.run_pass("peephole", foo.graph) 91 FileCheck().check("value=3").check_next("return").run(foo.graph) 92 93 @torch.jit.script 94 def foo(x, y, z): 95 li = [x, y, z] 96 for i in range(len(x)): 97 li.append(x) 98 return len([x, y, z]) 99 100 self.run_pass("peephole", foo.graph) 101 FileCheck().check_not("aten::len").run(foo.graph) 102 103 @torch.jit.script 104 def foo(x, y, z): 105 li = [x, y, z] 106 return li[1], li[-2] 107 108 FileCheck().check("aten::__getitem__").run(foo.graph) 109 self.run_pass("peephole", foo.graph) 110 FileCheck().check_not("aten::__getitem__").run(foo.graph) 111 112 @torch.jit.script 113 def foo(x, y, z): 114 li = [x, y, z] 115 return li[-7] 116 117 self.run_pass("peephole", foo.graph) 118 FileCheck().check("aten::__getitem__").run(foo.graph) 119 120 @torch.jit.script 121 def foo(x, y, z): 122 li = [x, y, z] 123 for i in range(len(x)): 124 li.append(x) 125 return li[-2] 126 127 self.run_pass("peephole", foo.graph) 128 FileCheck().check("aten::__getitem__").run(foo.graph) 129 130 @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA") 131 def test_peephole_cuda(self): 132 a = torch.tensor([0.4], device="cpu") 133 b = torch.tensor([0.7], device="cuda") 134 c = torch.tensor([0.7], device="cuda") 135 136 def f(x, y): 137 return x.type_as(y) 138 139 trace = torch.jit.trace(f, (a, c)) 140 s = str(trace.graph) 141 self.run_pass("peephole", trace.graph) 142 self.assertEqual(s, str(trace.graph)) 143 trace = torch.jit.trace(f, (b, c)) 144 self.run_pass("peephole", trace.graph) 145 self.run_pass("dce", trace.graph) 146 FileCheck().check_not("type_as").run(str(trace.graph)) 147 148 @_inline_everything 149 def test_peephole_type_refinements(self): 150 def refine(x): 151 # type: (Optional[Tensor]) -> Tensor 152 return x if x is not None else torch.tensor(3) 153 154 @torch.jit.script 155 def test(): 156 return refine(torch.tensor(4)) 157 158 FileCheck().check("prim::unchecked_cast").run(test.graph) 159 self.run_pass("peephole", test.graph) 160 FileCheck().check_not("prim::unchecked_cast").run(test.graph) 161 162 # refinement not optimzied out 163 def is_int_tensor(x): 164 scalar = x.item() 165 if isinstance(scalar, int): 166 return scalar + 3 167 else: 168 return 8 169 170 self.checkScript(is_int_tensor, (torch.tensor(2),)) 171 self.checkScript(is_int_tensor, (torch.tensor(2.5),)) 172 graph = torch.jit.script(is_int_tensor).graph 173 self.run_pass("peephole", graph) 174 FileCheck().check("prim::unchecked_cast").run(graph) 175 176 def test_short_circuit_optimization(self): 177 @torch.jit.script 178 def const_expressions(x): 179 # type: (int) -> Tuple[bool, bool] 180 return x == 1 and False, x == 1 or True 181 182 self.run_pass("constant_propagation", const_expressions.graph) 183 FileCheck().check_not("prim::If").check_not("aten::eq").run( 184 const_expressions.graph 185 ) 186 self.assertEqual(const_expressions(1), (False, True)) 187 188 @torch.jit.script 189 def redundant_expressions(x): 190 # type: (int) -> Tuple[bool, bool] 191 return x == 1 and True, x == 1 or False 192 193 self.run_pass("peephole", redundant_expressions.graph) 194 self.assertEqual(redundant_expressions(1), (True, True)) 195 self.assertEqual(redundant_expressions(0), (False, False)) 196 # and True / or False are removed from graph 197 FileCheck().check("aten::eq").check_not("prim::If").run( 198 redundant_expressions.graph 199 ) 200 201 def test_conv_dim_folding(self): 202 modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d] 203 for mod in modules: 204 205 class ConvDim(torch.nn.Module): 206 def __init__(self) -> None: 207 super().__init__() 208 self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False) 209 210 def forward(self, x): 211 x = self.conv(x) 212 return x.dim() 213 214 conv_dim = torch.jit.script(ConvDim()) 215 self.run_pass("inline", conv_dim.graph) 216 self.run_pass("peephole", conv_dim.graph) 217 FileCheck().check_not("conv").check_not("dim").run(conv_dim.graph) 218 219 class ConvDimMutate(torch.nn.Module): 220 def __init__(self) -> None: 221 super().__init__() 222 self.conv = mod(3, 32, kernel_size=3, stride=2, bias=False) 223 224 def forward(self, x): 225 x = self.conv(x) 226 x.resize_([4, 4]) 227 return x.dim() 228 229 conv_dim = torch.jit.script(ConvDimMutate()) 230 self.run_pass("inline", conv_dim.graph) 231 self.run_pass("peephole", conv_dim.graph) 232 FileCheck().check("conv").check("dim").run(conv_dim.graph) 233 234 def test_normalized_rsub(self): 235 a = torch.tensor([1, 2, 3]) 236 b = torch.tensor([4, 5, 6]) 237 238 def convertible_rsub(x, y): 239 return (x - y), torch.rsub(y, x) 240 241 self.checkScript(convertible_rsub, (a, b)) 242 op_graph = torch.jit.script(convertible_rsub).graph 243 FileCheck().check_count("aten::sub", 2, exactly=True).run(op_graph) 244 FileCheck().check_count("aten::rsub", 0, exactly=True).run(op_graph) 245 246 def test_normalized_is_op(self): 247 def convertible_is_op(x: bool, y: bool): 248 return x is True, False is x, x is y 249 250 self.checkScript(convertible_is_op, (True, False)) 251 252 op_graph = torch.jit.script(convertible_is_op).graph 253 FileCheck().check_count("aten::eq", 3, exactly=True).run(op_graph) 254 FileCheck().check_count("aten::__is__", 0, exactly=True).run(op_graph) 255 256 def test_normalized_isnot_op(self): 257 def convertible_isnot_op(x: bool, y: bool): 258 return x is not True, False is not x, x is not y 259 260 self.checkScript(convertible_isnot_op, (True, False)) 261 262 op_graph = torch.jit.script(convertible_isnot_op).graph 263 FileCheck().check_count("aten::ne", 3, exactly=True).run(op_graph) 264 FileCheck().check_count("aten::__isnot__", 0, exactly=True).run(op_graph) 265 266 def test_peephole_list_len(self): 267 def run_peephole_and_check_const_value(graph, const_string): 268 torch._C._jit_pass_peephole_list_idioms(graph, refine_list_len=True) 269 self.run_pass("constant_propagation", graph) 270 FileCheck().check(const_string).check_next("return").run(graph) 271 272 def gen_li(inp_len: int): 273 return [0 for i in range(inp_len)] 274 275 @torch.jit.script 276 def foo(x: List[int], y: List[int]): 277 if len(x) != 4 or len(y) != 5: 278 raise Exception("") # noqa: TRY002 279 280 return len(x) + len(y) 281 282 run_peephole_and_check_const_value(foo.graph, "value=9") 283 self.assertEqual(foo(gen_li(4), gen_li(5)), 9) 284 with self.assertRaises(Exception): 285 foo(2, 4) 286 287 @torch.jit.script 288 def foo(x: List[int], y: List[int]): 289 if len(x) == 4 and len(y) == 5: 290 pass 291 else: 292 raise Exception("hi") # noqa: TRY002 293 294 return len(x) + len(y) 295 296 run_peephole_and_check_const_value(foo.graph, "value=9") 297 self.assertEqual(foo(gen_li(4), gen_li(5)), 9) 298 with self.assertRaises(Exception): 299 foo(2, 4) 300 301 @torch.jit.script 302 def foo(x: List[int], y: List[int], z: List[int]): 303 if len(x) != 4: 304 raise Exception("..") # noqa: TRY002 305 else: 306 if len(y) != 8: 307 raise Exception("...") # noqa: TRY002 308 else: 309 if len(z) == 3: 310 pass 311 else: 312 raise Exception("...") # noqa: TRY002 313 314 return len(x) + len(y) * len(z) 315 316 run_peephole_and_check_const_value(foo.graph, "value=28") 317 self.assertEqual(foo(gen_li(4), gen_li(8), gen_li(3)), 28) 318 with self.assertRaises(Exception): 319 foo(1, 2, 3) 320 321 # refinement should persist in second len(x) call 322 323 @torch.jit.script 324 def foo(x: List[int], cond: bool): 325 if len(x) == 4: 326 if cond: 327 return len(x) 328 return 4 329 330 return 4 331 332 run_peephole_and_check_const_value(foo.graph, "value=4") 333 334 def test_const_tuple_output(graph, const_inputs): 335 tup = graph.findNode("prim::TupleConstruct") 336 for i, elem in enumerate(tup.inputs()): 337 if i in const_inputs: 338 self.assertIsNotNone(elem.toIValue()) 339 else: 340 self.assertIsNone(elem.toIValue()) 341 342 # testing combinations of x1 : {True, False} x 343 # {then/else branch} x assert {True/False} 344 345 @torch.jit.script 346 def foo(x: List[int], b: List[int]): 347 if len(x) == 5: 348 x1 = True 349 else: 350 x1 = len(b) != 4 351 assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq 352 return len(x), len(b) 353 354 torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 355 torch._C._jit_pass_constant_propagation(foo.graph) 356 # we can only infer len(b) == 4 here 357 test_const_tuple_output(foo.graph, [1]) 358 359 @torch.jit.script 360 def foo(x: List[int], b: List[int]): 361 if len(x) == 5: 362 x1 = False 363 else: 364 x1 = len(b) != 4 365 assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq 366 return len(x), len(b) 367 368 torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 369 torch._C._jit_pass_constant_propagation(foo.graph) 370 # cant infer anything 371 test_const_tuple_output(foo.graph, []) 372 373 @torch.jit.script 374 def foo(x: List[int], b: List[int]): 375 if len(x) == 5: 376 x1 = True 377 else: 378 x1 = len(b) == 4 379 assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq 380 return len(x), len(b) 381 382 torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 383 torch._C._jit_pass_constant_propagation(foo.graph) 384 # we cant infer anything, only len(b) != 4 385 test_const_tuple_output(foo.graph, []) 386 387 @torch.jit.script 388 def foo(x: List[int], b: List[int]): 389 if len(x) == 5: 390 x1 = True 391 else: 392 x1 = len(b) != 4 393 assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq 394 return len(x), len(b) 395 396 torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 397 torch._C._jit_pass_constant_propagation(foo.graph) 398 # can infer len(b) == 4 399 test_const_tuple_output(foo.graph, [1]) 400 401 # swap branches 402 @torch.jit.script 403 def foo(x: List[int], b: List[int]): 404 if len(x) != 5: 405 x1 = len(b) != 4 406 else: 407 x1 = True 408 assert x1 == False # noqa: E712 TODO: canonicalize x is False to aten::eq 409 return len(x), len(b) 410 411 torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 412 torch._C._jit_pass_constant_propagation(foo.graph) 413 # can infer len(b) == 4 414 test_const_tuple_output(foo.graph, [1]) 415 416 # use __not__ 417 @torch.jit.script 418 def foo(x: List[int], b: List[int]): 419 if len(x) != 5: 420 x1 = len(b) != 4 421 else: 422 x1 = True 423 assert not x1 424 return len(x), len(b) 425 426 torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 427 torch._C._jit_pass_constant_propagation(foo.graph) 428 # can infer len(b) == 4 429 test_const_tuple_output(foo.graph, [1]) 430 431 # Test unsuccessful optimizations 432 433 @torch.jit.script 434 def foo(x: List[int]): 435 assert len(x) == 4 436 x.append(3) 437 return len(x) 438 439 torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 440 self.run_pass("constant_propagation", foo.graph) 441 FileCheck().check_count("aten::len", 2).run(foo.graph) 442 443 @torch.jit.script 444 def foo(x: List[int], y: List[int]): 445 assert len(x) == 4 or len(y) == 5 446 return len(x) + len(y) 447 448 torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 449 self.run_pass("constant_propagation", foo.graph) 450 FileCheck().check_count("aten::len", 4).run(foo.graph) 451 452 def test_integer_refinement(self): 453 def run_peephole_and_check_const_value(graph, const_string): 454 self.run_pass("refine_integer_values", graph) 455 self.run_pass("constant_propagation", graph) 456 self.run_pass("dce", graph) 457 FileCheck().check(const_string).check_next("return").run(graph) 458 459 @torch.jit.script 460 def foo(x: int, y: int): 461 if x != 4 or y != 5: 462 raise Exception("") # noqa: TRY002 463 464 return x + y 465 466 graph = foo.graph 467 self.run_pass("refine_integer_values", graph) 468 self.run_pass("constant_propagation", graph) 469 self.run_pass("dce", graph) 470 471 run_peephole_and_check_const_value(foo.graph, "value=9") 472 self.assertEqual(foo(4, 5), 9) 473 with self.assertRaises(Exception): 474 foo(2, 4) 475 476 @torch.jit.script 477 def foo(x: int, y: int): 478 if x == 4 and y == 5: 479 pass 480 else: 481 raise Exception("hi") # noqa: TRY002 482 483 return x + y 484 485 run_peephole_and_check_const_value(foo.graph, "value=9") 486 self.assertEqual(foo(4, 5), 9) 487 with self.assertRaises(Exception): 488 foo(2, 4) 489 490 @torch.jit.script 491 def foo(x: int, y: int, z: int): 492 if x != 4: 493 raise Exception("..") # noqa: TRY002 494 else: 495 if y != 8: 496 raise Exception("...") # noqa: TRY002 497 else: 498 if z == 3: 499 pass 500 else: 501 raise Exception("...") # noqa: TRY002 502 503 return x + y * z 504 505 run_peephole_and_check_const_value(foo.graph, "value=28") 506 self.assertEqual(foo(4, 8, 3), 28) 507 with self.assertRaises(Exception): 508 foo(1, 2, 3) 509 510 # refinement should persist in second len(x) call 511 512 @torch.jit.script 513 def foo(x: int, cond: bool): 514 if x == 4: 515 if cond: 516 return x 517 return 4 518 519 return 4 520 521 run_peephole_and_check_const_value(foo.graph, "value=4") 522 523 @torch.jit.script 524 def foo(x: int, y: int): 525 assert x == 4 or y == 5 526 return x + y 527 528 torch._C._jit_pass_peephole_list_idioms(foo.graph, refine_list_len=True) 529 self.run_pass("constant_propagation", foo.graph) 530 FileCheck().check("aten::add").run(foo.graph) 531 532 def test_optimize_out_comparison_same_value(self): 533 def foo(x: int): 534 return x == x, x != x 535 536 def foo2(x: List[int]): 537 return x == x, x != x 538 539 for func, inp in zip([foo, foo2], [1, [2, 3]]): 540 func_s = torch.jit.script(func) 541 self.run_pass("peephole", func_s.graph) 542 FileCheck().check_not("aten::eq").check_not("aten::neq").run(func_s.graph) 543 self.assertEqual(func(inp), func_s(inp)) 544 545 def test_peephole_add_zero(self): 546 @torch.jit.script 547 def foo(x: int): 548 return x + 0, 0 + x 549 550 self.run_pass("peephole", foo.graph) 551 FileCheck().check_not("aten::add") 552 self.assertEqual(foo(3), (3, 3)) 553 554 def test_noop_peephole(self): 555 # test unsuccessful 556 def foo1(x): 557 return x + 0 558 559 def foo2(): 560 x = torch.zeros([2, 2]) 561 x.sub_(3) 562 return x + 0 563 564 def foo3(): 565 x = torch.zeros([2, 2]) 566 return x, x + 0 567 568 def foo4(): 569 x = torch.zeros([2, 2]) 570 return x + 0.0 571 572 funcs = foo1, foo2, foo3, foo4 573 inps = (torch.ones([2]),), (), (), () 574 for func, inp in zip(funcs, inps): 575 foo_s = torch.jit.script(func) 576 self.run_pass("peephole", foo_s.graph) 577 FileCheck().check_count("aten::add", 1, exactly=True).run(foo_s.graph) 578 self.assertEqual(func(*inp), foo_s(*inp)) 579 580 # successful 581 def func(x): 582 return (x + 0) * 1 - 5 583 584 func_s = torch.jit.script(func) 585 self.run_pass("peephole", func_s.graph) 586 # bail on modified value first 587 FileCheck().check_not("aten::add").check("aten::mul").run(func_s.graph) 588 # second run it should succeed 589 self.run_pass("peephole", func_s.graph) 590 FileCheck().check_not("aten::add").check_not("aten::mul").run(func_s.graph) 591 self.assertEqual(func(torch.ones([2, 2])), func_s(torch.ones([2, 2]))) 592 593 def func(x): 594 return (x + 0.0) - 5 595 596 func_s = torch.jit.script(func) 597 inp = next(func_s.graph.inputs()) 598 inp.setType(torch._C.TensorType.create_from_tensor(torch.rand([2, 2]))) 599 torch._C._jit_pass_peephole(func_s.graph, disable_shape_peepholes=True) 600 FileCheck().check("aten::add").run(func_s.graph) 601 torch._C._jit_pass_peephole(func_s.graph, disable_shape_peepholes=False) 602 FileCheck().check_not("aten::add").run(func_s.graph) 603 604 def test_refine_integer_values(self): 605 @torch.jit.script 606 def foo(x: int): 607 y = 1 608 if x == 1: 609 return y 610 else: 611 return x 612 613 self.run_pass("refine_integer_values", foo.graph) 614 self.run_pass("constant_propagation", foo.graph) 615 self.run_pass("dce", foo.graph) 616 FileCheck().check("graph").check_next("return").run(foo.graph) 617 self.assertEqual(foo(2), 2) 618 self.assertEqual(foo(1), 1) 619 620 def test_peephole_len_list(self): 621 @torch.jit.script 622 def foo(x): 623 return len(x.size()) 624 625 self.run_pass("peephole", foo.graph) 626 FileCheck().check("aten::len").run(foo.graph) 627 inputs = list(foo.graph.inputs()) 628 inputs[0].setType(inputs[0].type().with_sizes([None, None])) 629 self.run_pass("peephole", foo.graph) 630 FileCheck().check_not("aten::len").run(foo.graph) 631 self.assertEqual(2, foo(torch.rand([3, 1]))) 632 633 @torch.jit.script 634 def foo(x): 635 li = x.size() 636 li.append(4) 637 return len(li) 638 639 inputs = list(foo.graph.inputs()) 640 inputs[0].setType(inputs[0].type().with_sizes([None, None])) 641 self.run_pass("peephole", foo.graph) 642 FileCheck().check("aten::len").run(foo.graph) 643 self.assertEqual(3, foo(torch.rand([3, 1]))) 644 645 def test_peephole_optional_refine(self): 646 @torch.jit.script 647 def foo(z: int, z2: int, cond: bool): 648 if cond: 649 return z 650 else: 651 return z2 652 653 out = next(foo.graph.findNode("prim::If").outputs()) 654 out.setType(torch._C.OptionalType(torch._C.IntType.get())) 655 self.run_pass("peephole", foo.graph) 656 FileCheck().check_not("int?").run(foo.graph) 657 658 def test_peephole_int(self): 659 @torch.jit.script 660 def foo(x): 661 # type: (number) 662 return int(x) 663 664 FileCheck().check("aten::Int").run(foo.graph) 665 next(foo.graph.inputs()).setType(torch._C.IntType.get()) 666 self.run_pass("peephole", foo.graph) 667 FileCheck().check_not("aten::Int").run(foo.graph) 668 669 def test_peephole_arith(self): 670 @torch.jit.script 671 def foo(input0: int, input1: int, input2: int, input3: int): 672 _1 = torch.add(input1, 2) 673 _3 = torch.add(input3, 2) 674 _5 = torch.add(1, torch.sub(_1, 3) // 1) 675 _6 = torch.add(1 * torch.sub(_3, 3) // 1, 1) / 1 676 return [_5, int(_6)] 677 678 FileCheck().check("aten::add").check("aten::sub").check("aten::mul").check( 679 "aten::floordiv" 680 ).check("aten::div").run(foo.graph) 681 self.run_pass("peephole", foo.graph) 682 FileCheck().check("graph").check("):").check_next("ListConstruct").check_next( 683 "return" 684 ).run(foo.graph) 685 self.assertEqual(foo(0, 1, 2, 3), [1, 3]) 686 687 def test_peephole_dict_getitem_simple(self): 688 @torch.jit.script 689 def foo(a: int, b: int): 690 d = {0: a, 1: b} 691 x = d[1] 692 y = d[0] 693 return x, y 694 695 self.run_pass("peephole", foo.graph) 696 FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph) 697 self.assertEqual(foo(0, 1), (1, 0)) 698 699 @torch.jit.script 700 def foo(a: int, b: int): 701 d = {"0": a, "1": b} 702 x = d["1"] 703 y = d["0"] 704 return x, y 705 706 self.run_pass("peephole", foo.graph) 707 FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph) 708 self.assertEqual(foo(0, 1), (1, 0)) 709 710 @torch.jit.script 711 def foo(a: int, b: int): 712 d = {0.0: a, 1.0: b} 713 x = d[1.0] 714 y = d[0.0] 715 return x, y 716 717 self.run_pass("peephole", foo.graph) 718 FileCheck().check_not("DictConstruct").check_not("__getitem__").run(foo.graph) 719 self.assertEqual(foo(0, 1), (1, 0)) 720 721 def test_peephole_dict_getitem_no_optimization_missing_key(self): 722 @torch.jit.script 723 def foo(): 724 d = {0: 1} 725 return d[2] 726 727 self.run_pass("peephole", foo.graph) 728 FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 729 730 def test_peephole_dict_getitem_no_optimization_get_input_arg(self): 731 # Here we don't know if the input arg is in the dict, so we can't 732 # make the optimization. 733 @torch.jit.script 734 def foo(a: int): 735 d = {0: 1} 736 return d[a] 737 738 self.run_pass("peephole", foo.graph) 739 FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 740 self.assertEqual(foo(0), 1) 741 742 def test_peephole_dict_getitem_no_optimization_dict_modified(self): 743 @torch.jit.script 744 def foo(): 745 d = {0: 1} 746 d[0] = 2 747 return d[0] 748 749 self.run_pass("peephole", foo.graph) 750 FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 751 self.assertEqual(foo(), 2) 752 753 def test_peephole_dict_getitem_no_optimization_overlapping_keys(self): 754 @torch.jit.script 755 def foo(): 756 d = {0: 1, 0: 2} # noqa: F601 757 return d[0] 758 759 self.run_pass("peephole", foo.graph) 760 FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 761 762 def test_peephole_dict_getitem_no_optimization_keys_might_overlap(self): 763 @torch.jit.script 764 def foo(x: int): 765 d = {0: 1, x: 2} 766 return d[x] 767 768 self.run_pass("peephole", foo.graph) 769 FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 770 771 def test_peephole_dict_getitem_no_optimization_unsupported_type(self): 772 @torch.jit.script 773 def foo(): 774 a = torch.rand((2, 2)) 775 d = {a: 1} 776 return d[a] 777 778 self.run_pass("peephole", foo.graph) 779 FileCheck().check("DictConstruct").check("__getitem__").run(foo.graph) 780 self.assertEqual(foo(), 1) 781 782 def test_peephole_dict_len(self): 783 @torch.jit.script 784 def foo(): 785 d = {0: 1, 1: 2} 786 return len(d) 787 788 self.run_pass("peephole", foo.graph) 789 FileCheck().check_not("DictConstruct").check_not("len").run(foo.graph) 790 self.assertEqual(foo(), 2) 791 792 def test_peephole_dict_len_no_optimization_overlapping_keys(self): 793 @torch.jit.script 794 def foo(): 795 d = {0: 1, 0: 2} # noqa: F601 796 return len(d) 797 798 self.run_pass("peephole", foo.graph) 799 FileCheck().check("DictConstruct").check("len").run(foo.graph) 800 self.assertEqual(foo(), 1) 801 802 def test_peephole_dict_len_no_optimization_keys_might_overlap(self): 803 @torch.jit.script 804 def foo(x: int): 805 d = {0: 1, x: 2} 806 return len(d) 807 808 self.run_pass("peephole", foo.graph) 809 FileCheck().check("DictConstruct").check("len").run(foo.graph) 810 811 def test_peephole_dict_len_no_optimization_unsupported_type(self): 812 @torch.jit.script 813 def foo(): 814 a = torch.rand((2, 2)) 815 d = {a: 1} 816 return len(d) 817 818 self.run_pass("peephole", foo.graph) 819 FileCheck().check("DictConstruct").check("len").run(foo.graph) 820 self.assertEqual(foo(), 1) 821 822 def test_peephole_slice_all_three_args(self): 823 def foo(x: int): 824 return [1, 2, x, 4, 5, 6, 7][-5:6:2] 825 826 graph = torch.jit.script(foo).graph 827 self.run_pass("peephole", graph) 828 FileCheck().check_not("aten::slice").run(graph) 829 self.checkScript(foo, (3,)) 830 831 def test_peephole_slice_one_empty_arg(self): 832 def check_helper(fn: Callable[[int], None]) -> None: 833 graph = torch.jit.script(fn).graph 834 self.run_pass("peephole", graph) 835 FileCheck().check_not("aten::slice").run(graph) 836 self.checkScript(fn, (3,)) 837 838 def foo(x: int): 839 return [1, 2, x, 4, 5, 6, 7][1::2] 840 841 check_helper(foo) 842 843 def foo(x: int): 844 return [1, 2, x, 4, 5, 6, 7][:5:3] 845 846 check_helper(foo) 847 848 def foo(x: int): 849 return [1, 2, x, 4, 5, 6, 7][0:4] 850 851 check_helper(foo) 852 853 def test_peephole_slice_two_empty_args(self): 854 def check_helper(fn: Callable[[int], None]) -> None: 855 graph = torch.jit.script(fn).graph 856 self.run_pass("peephole", graph) 857 FileCheck().check_not("aten::slice").run(graph) 858 self.checkScript(fn, (3,)) 859 860 def foo(x: int): 861 return [1, 2, x, 4, 5, 6, 7][::2] 862 863 check_helper(foo) 864 865 def foo(x: int): 866 return [1, 2, x, 4, 5, 6, 7][:5] 867 868 check_helper(foo) 869 870 def foo(x: int): 871 return [1, 2, x, 4, 5, 6, 7][1:] 872 873 check_helper(foo) 874 875 def test_peephole_slice_optimization_not_applied_list_modified(self): 876 @torch.jit.script 877 def foo(): 878 li = [1, 2, 3, 4, 5, 6, 7] 879 li[0] = 0 880 return li[2:5] 881 882 self.run_pass("peephole", foo.graph) 883 FileCheck().check("aten::slice").run(foo.graph) 884 885 def test_peephole_slice_optimization_not_applied_non_const_args(self): 886 @torch.jit.script 887 def foo(x: int, y: int): 888 li = [1, 2, 3, 4, 5, 6, 7] 889 return li[x:y] 890 891 self.run_pass("peephole", foo.graph) 892 FileCheck().check("aten::slice").run(foo.graph) 893