# Owner(s): ["module: dynamo"]
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import unsupported
from torch._dynamo.utils import ifdynstaticdefault


globalmod = torch.nn.ReLU()


def indirectly_unsupported(a, b):
    c = a + b
    return unsupported(a, c)


class SubGraphTests(torch._dynamo.test_case.TestCase):
    def _common(self, fn, frame_count, op_count):
        torch._dynamo.reset()
        v1 = torch.ones(10)
        v2 = torch.ones(10) * -2.0
        correct1 = fn(v1, v2)
        correct2 = fn(v2, v1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        r1 = opt_fn(v1, v2)
        r2 = opt_fn(v2, v1)
        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
        self.assertEqual(
            cnt.frame_count,
            frame_count,
            f"actual {cnt.frame_count} != expected {frame_count}",
        )
        self.assertEqual(cnt.op_count, op_count)

    def test_control_flow1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            if c1.sum() > c2.sum():
                return c1
            else:
                return c2

        self._common(fn, 1, 5)

    def test_control_flow2(self):
        def fn(a, b):
            if a.sum() > b.sum():
                return 1
            else:
                return 2

        self._common(fn, 1, 3)

    def test_control_flow3(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            m = globalmod
            if c1.sum() > c2.sum():
                return m(c1)
            else:
                return m(c2)

        self._common(fn, 3, 7)

    def test_control_flow4(self):
        def fn(a, b):
            tmp1 = a.sum() > b.sum() and a.sum() > 0
            if tmp1:
                return 1
            else:
                return 2

        self._common(fn, 3, 5)

    def test_control_flow5(self):
        def fn(a, b):
            tmp1 = a.sum() > b.sum() and a.sum() > 0
            tmp2 = a.sum() < b.sum() or b.sum() > 0
            if tmp1 and tmp2:
                return 1, tmp1, tmp2
            else:
                return 2, tmp1, tmp2

        self._common(fn, 6, 13)

    def test_capi_call1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return unsupported(c1, c2)

        self._common(fn, 1, 2)

    def test_capi_call2(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return a - (b - unsupported(c1, c2))

        self._common(fn, 2, 4)

    def test_capi_call3(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return torch._dynamo.testing.unsupported(c1, c2)

        self._common(fn, 1, 2)

    def test_indirect_unsupported1(self):
        def fn(a, b):
            c1 = a - b
            c2 = b - a
            return indirectly_unsupported(c1, c2)

        self._common(fn, 2, 3)

    def test_indirect_unsupported2(self):
        def fn(a, b):
            local_const1 = 7
            local_const2 = 22
            c1 = a - b
            c2 = b - a
            return local_const1 / (local_const2 - indirectly_unsupported(c1, c2))

        self._common(fn, 3, 5)

    def test_indirect_unsupported3(self):
        def fn(a, b):
            args = [a - b, b - a]
            return indirectly_unsupported(*args)

        self._common(fn, 2, 3)

    def test_stack_state1(self):
        def fn(a, b):
            t1 = 1.23 * a
            t2 = 4.56 * a
            c1 = a - b
            c2 = b - a
            return t1 / (t2 - unsupported(c1, c2))

        self._common(fn, 2, 6)

    def test_stack_state2(self):
        def fn(a, b):
            t1 = 1.23 * a
            t2 = 4.56 * a
            c1 = a - b
            c2 = b - a
            return t1 / (t2 - indirectly_unsupported(c1, c2))

        self._common(fn, 3, 7)

    def test_multigraph(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            if x.sum() < 0:
                return x * -1.0
            return x

        self._common(fn, 2, 5)

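    # The ~512 chained adds below make the lambda's bytecode long enough that
    # CPython has to emit EXTENDED_ARG-prefixed instructions, which this test
    # exercises.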
    def test_extended_args(self):
        too_many_adds = "+".join(["a", "b"] * 256)
        source = (
            f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)"
        )
        self._common(eval(source), 3, 1026)

    def test_resume1(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = unsupported(x, a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 6)

    def test_resume2(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(x, a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume3(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(x, b=a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume4(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            x = indirectly_unsupported(a=x, b=a)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 3, 7)

    def test_resume5(self):
        def fn(a, b):
            x = a + b
            x = x / 2.0
            x = x + 2.0
            print(x)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 6)

    def test_start1(self):
        def fn(a, b):
            print(a)
            x = a + b
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 1, 3)

    def test_start2(self):
        def fn(a, b):
            x = indirectly_unsupported(a, b)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 2, 4)

    def test_start3(self):
        def fn(a, b):
            x = unsupported(a, b)
            x = x + 2.0
            x = x + 2.0
            x = x + 2.0
            return x

        self._common(fn, 1, 3)

    def test_start4(self):
        def fn(a, b, check):
            if check:
                return a + b + 10
            else:
                return a + b - 10

        v1 = torch.randn(10)
        v2 = torch.randn(10)
        f = torch.zeros(1, dtype=torch.int32)
        t = torch.ones(1, dtype=torch.int32)
        correct1 = fn(v1, v2, t)
        correct2 = fn(v1, v2, f)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        r1 = opt_fn(v1, v2, t)
        r2 = opt_fn(v1, v2, f)
        self.assertTrue(torch._dynamo.testing.same(r1, correct1))
        self.assertTrue(torch._dynamo.testing.same(r2, correct2))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 4)

    def test_resume_freevars(self):
        c1 = torch.randn(10)
        c2 = torch.randn(10)

        def fn(a, b):
            x = a + b + (c1 - c2)
            x = unsupported(x, x)
            return x + (c1 - c2)

        self._common(fn, 2, 5)

    def test_restore_state(self):
        def fn(a, b):
            len_ = len
            x = a + b
            x = torch.add(unsupported(x, x), 1)
            return a * x + len_(b)

        self._common(fn, 2, 4)

    def test_restore_range(self):
        def fn(a, b):
            x = a + b
            rng = range(3, 8, 2)
            x = unsupported(x, x)
            for i in rng:
                x = x + i
            return x

        # We don't specialize on range with dynamic shapes, which
        # means we fail to unroll the loop.
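        # (ifdynstaticdefault(a, b) returns `a` when
        # torch._dynamo.config.assume_static_by_default is set and `b`
        # otherwise, so the expected counts differ by shape mode.)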
        # TODO: Consider forcing specialization when we iterate over
        # the loop
        self._common(fn, ifdynstaticdefault(2, 1), ifdynstaticdefault(4, 1))

    def test_restore_range_iter(self):
        def fn(a, b):
            x = a + b
            rng = iter(range(3, 8, 2))
            x = unsupported(x, x)
            x += next(rng)
            return x, list(rng)

        self._common(fn, 2, 2)

    def test_pop_after_resume(self):
        def fn(a, b):
            tmp = [a + 1, b + 2, a + b]
            x = a
            x = unsupported(x, x)
            for i in range(3):
                x += tmp.pop(-1)
            return x

        self._common(fn, 2, 6)

    @patch("torch._dynamo.config.assume_static_by_default", False)
    def test_dynamic_getitem(self):
        def fn(a, b):
            return a[b.size(0) - 1]

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        for i in range(3, 12):
            opt_fn(torch.randn(i), torch.randn(i))
        # just one graph
        self.assertEqual(cnt.frame_count, 1)

    def test_dynamic_kwarg(self):
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        start = 2
        end = 12
        steps = end - start
        for i in range(start, end):
            opt_fn(torch.randn(i), torch.randn(i))

        self.assertEqual(cnt_dynamic.frame_count, 1)

    def test_dynamic_duck_size(self):
        def fn(a, b):
            if a.size(0) == b.size(0):
                return a + b
            else:
                return a.sum() + b.sum()

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        x = torch.randn(2)
        y = torch.randn(3)
        self.assertEqual(opt_fn(x, x), fn(x, x))
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    def test_dynamic_order_dependence(self):
        def fn(a, b):
            return a.sum() + b.sum()

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
        x = torch.randn(2)
        y = torch.randn(3)
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(opt_fn(x, x), fn(x, x))
        # NB: This COULD validly be 2, but we don't test disjointness in the
        # guards for when x and y didn't duck size together, so we end up
        # with a generic graph that also works when x and y happen to duck
        # size together.
        self.assertEqual(cnt_dynamic.frame_count, 2)

        torch._dynamo.reset()
        cnt_dynamic.frame_count = 0
        self.assertEqual(opt_fn(x, x), fn(x, x))  # this overspecializes!
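        # Compiling on x, x first appears to specialize the graph for the
        # equal-size case, so the x, y call below triggers a recompile.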
        self.assertEqual(opt_fn(x, y), fn(x, y))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    def test_dynamic_zero_inference(self):
        def fn(a):
            if a.size(0) != 0:
                return a * 2
            else:
                return a + 1

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn)
        x = torch.randn(0)
        y = torch.randn(2)
        self.assertEqual(opt_fn(y), fn(y))
        self.assertEqual(opt_fn(x), fn(x))
        self.assertEqual(cnt_dynamic.frame_count, 2)

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_no_graph_break_on_item(self):
        def fn(a, b):
            x = a + b - 1.5
            x = x.sum()
            x.item()
            x = x / (a + b)
            return x

        self._common(fn, 1, 5)  # item gets DCE'd

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
    def test_graph_break_on_item(self):
        def fn(a, b):
            x = a + b - 1.5
            x = x.sum()
            x.item()
            x = x / (a + b)
            return x

        self._common(fn, 2, 5)

    def test_resume_paths_join(self):
        def fn(x, c1, c2, c3):
            x = x + 1
            if c1:
                x = x + 2
            x = x + 3
            if c2:
                x = x + 4
            x = x + 5
            if c3:
                x = x + 6
            return x + 7

        v1 = torch.randn(10)
        t = torch.Tensor([True])
        f = torch.Tensor([False])
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        for a in (t, f):
            for b in (t, f):
                for c in (t, f):
                    opt_fn(v1, a, b, c)

        # checking here we don't create 2^n graphs
        self.assertEqual(cnt.frame_count, 7)
        self.assertEqual(cnt.op_count, 10)

    def test_resume_with_no_grad1(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                x = x + 1
                x.sum().tolist()  # graph break
                x = x + 2
            x = x + 3
            return x

        self._common(fn, 2, 9)
        torch._dynamo.reset()
        with torch.no_grad():
            self._common(fn, 2, 5)

    def test_resume_with_no_grad2(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                x = x + 1
                x.sum().tolist()  # graph break
                x = x + 2
                x.sum().tolist()  # graph break
                x = x + 3
            x = x + 4
            return x

        self._common(fn, 3, 13)

    def test_resume_with_no_grad3(self):
        def fn(a, b):
            x = a + b
            with torch.no_grad():
                with torch.no_grad():
                    x = x + 1
                    with torch.enable_grad():
                        x.sum().tolist()  # graph break
                        x = x[0] + 2
                    x = x + 3
            x = x + 4
            return x

        self._common(fn, 2, 11)

    def test_resume_tuple_iterator(self):
        def fn(a, b):
            x = a + b
            it = iter(tuple(range(10)))
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            return x

        self._common(fn, 2, 8)

    def test_tuple_iterator_return(self):
        def fn(x):
            it = iter(tuple(range(10)))
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            x = unsupported(x, x)
            x = x + next(it)
            x = x + next(it)
            return x, it

        v1 = torch.randn(10)
        v2, it2 = fn(v1)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        v3, it3 = opt_fn(v1)
        v4, it4 = opt_fn(v1)
        self.assertEqual(v2.tolist(), v3.tolist())
        self.assertEqual(v2.tolist(), v4.tolist())
        self.assertEqual(list(it2), list(it3))
        self.assertEqual(cnt.frame_count, 3)
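        # the two unsupported() calls split fn's six adds across three graphs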
        self.assertEqual(cnt.op_count, 6)

    def test_tuple_iterator_mutate(self):
        def fn(x, it):
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            x = x + next(it)
            return x

        v1 = torch.randn(10)
        it1 = iter(tuple(range(10)))
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        self.assertEqual(opt_fn(v1, it1).tolist(), (v1 + 1 + 2 + 3).tolist())
        self.assertEqual(list(it1), [4, 5, 6, 7, 8, 9])

    def test_enumerate_not_break_graph(self):
        def fn(a, b):
            for i, x in enumerate(a.shape):
                b = b + x
            for i, x in enumerate(b.shape, 8):
                b = b + x * i
            return b

        self._common(fn, 1, ifdynstaticdefault(2, 3))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()