1# Owner(s): ["module: inductor"] 2import itertools 3import unittest 4 5import torch 6import torch._dynamo.testing 7from torch._higher_order_ops.associative_scan import associative_scan 8from torch._inductor.test_case import TestCase 9from torch.testing._internal.common_utils import ( 10 decorateIf, 11 instantiate_parametrized_tests, 12 parametrize, 13) 14from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU 15from torch.testing._internal.triton_utils import requires_gpu 16 17 18def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1): 19 result = [] 20 device = inputs[0].device 21 # iterate over the cartesian product of predicate values 22 for values in itertools.product(*([possible_values] * num_to_prepend)): 23 prepended = [torch.tensor(v, device=device) for v in values] 24 result.append((*prepended, *inputs)) 25 return result 26 27 28def prepend_predicates(inputs, num_predicates=1): 29 return _prepend_product_of_values(inputs, [False, True], num_predicates) 30 31 32def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)): 33 return _prepend_product_of_values(inputs, counter_values, num_counters) 34 35 36class CondModels: 37 class Simple(torch.nn.Module): 38 def forward(self, p, a, b): 39 def true_fn(x, y): 40 return x + y 41 42 def false_fn(x, y): 43 return x - y 44 45 return torch.cond(p, true_fn, false_fn, [a, b]) 46 47 class Nested(torch.nn.Module): 48 def forward(self, p0, p1, p2, a, b, c): 49 def true_fn(x0, y0, z0): 50 def true_true_fn(x1, y1, z1): 51 return (x1 - y1 * z1) * 3.14 52 53 def true_false_fn(x1, y1, z1): 54 def true_false_true_fn(x2, y2, z2): 55 return (x2 * y2 * z2) / 2.71 56 57 def true_false_false_fn(x2, y2, z2): 58 return (x2 + y2 + z2) * 1.23 59 60 return torch.cond( 61 p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1] 62 ) 63 64 return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0]) 65 66 def false_fn(x0, y0, z0): 67 def false_true_fn(x1, y1, z1): 68 def false_true_true_fn(x2, y2, z2): 69 return (x2 - y2 - z2) + 1.23 70 71 def false_true_false_fn(x2, y2, z2): 72 return (x2 / y2 / z2) - 3.14 73 74 return torch.cond( 75 p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1] 76 ) 77 78 def false_false_fn(x1, y1, z1): 79 return (x1 - y1 * z1) / 2.71 80 81 return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0]) 82 83 return torch.cond(p0, true_fn, false_fn, [a, b, c]) 84 85 class Parameters(torch.nn.Module): 86 class InnerModel1(torch.nn.Module): 87 def __init__(self, device): 88 super().__init__() 89 self.layer = torch.nn.Linear(20, 30, device=device) 90 91 def forward(self, x): 92 return self.layer(x + 1) * 3.14 93 94 class InnerModel2(torch.nn.Module): 95 def __init__(self, device): 96 super().__init__() 97 self.layer1 = torch.nn.Linear(20, 10, device=device) 98 self.layer2 = torch.nn.Linear(10, 30, device=device) 99 100 def forward(self, x): 101 return self.layer2(self.layer1(x - 2)) * 3.14 102 103 def __init__(self, device): 104 super().__init__() 105 self.true_fn = self.InnerModel1(device) 106 self.false_fn = self.InnerModel2(device) 107 108 def forward(self, p, a): 109 return torch.cond(p, self.true_fn, self.false_fn, [a]) 110 111 class ReinterpretView(torch.nn.Module): 112 def forward(self, p, a, b): 113 def true_fn(x, y): 114 z1 = x + y 115 z2 = x - y 116 return z1[2:], z2[:, 4:] 117 118 def false_fn(x, y): 119 z1 = x - y 120 z2 = x + y 121 return z1[2:], z2[:, 4:] 122 123 return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]]) 124 125 class MultipleOutputs(torch.nn.Module): 126 def forward(self, p, a, b, c): 127 def true_fn(x, y, z): 128 return x * y, z / 2.71, (y - x).sum(dim=1) 129 130 def false_fn(x, y, z): 131 return y / x, z * 3.14, (x + y).mean(dim=1) 132 133 return torch.cond(p, true_fn, false_fn, [a, b, c]) 134 135 class OuterCode(torch.nn.Module): 136 def forward(self, p, a, b): 137 c = a * b + 3.14 138 d = a / b - 2.71 139 140 def true_fn(x, y): 141 return x + y 142 143 def false_fn(x, y): 144 return x - y 145 146 e = torch.cond(p, true_fn, false_fn, [c, d]) 147 148 return e * e / 1.41 149 150 class OuterBuffers(torch.nn.Module): 151 def forward(self, p, a, b, c): 152 d = a * 2 153 e = b / 2 154 155 def true_fn(x): 156 return x + d 157 158 def false_fn(x): 159 return x - e 160 161 return torch.cond(p, true_fn, false_fn, [c]) 162 163 class WithNonTensorPredicate(torch.nn.Module): 164 def forward(self, a, b): 165 def true_fn(x, y): 166 return x.sum(0) / 3.14 167 168 def false_fn(x, y): 169 return y.sum(0) * 2.71 170 171 return torch.cond(a.size(0) > b.size(0), true_fn, false_fn, [a, b]) 172 173 174class CondTests(TestCase): 175 def _run_test( 176 self, 177 model, 178 inputs, 179 device, 180 dynamic=False, 181 num_predicates=1, 182 ): 183 cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor") 184 compiled_model = torch.compile(backend=cnt, fullgraph=True)(model) 185 186 inputs = [inp.to(device=device) for inp in inputs] 187 input_sets = [inputs] 188 if dynamic: 189 larger_inputs = [] 190 for inp in inputs: 191 # tile every first dim 5x 192 tiling = [5] + [1] * (inp.ndim - 1) 193 larger_inputs.append(torch.tile(inp, tiling)) 194 input_sets.append(larger_inputs) 195 for inputs in input_sets: 196 for inp in inputs: 197 # mark every first dim as dynamic 198 torch._dynamo.mark_dynamic(inp, 0) 199 200 for inputs in input_sets: 201 for inputs_with_predicates in prepend_predicates(inputs, num_predicates): 202 cloned_inputs = [inp.clone() for inp in inputs_with_predicates] 203 result = model(*inputs_with_predicates) 204 result_compiled = compiled_model(*inputs_with_predicates) 205 # inputs must not be mutated 206 torch.testing.assert_close(cloned_inputs, inputs_with_predicates) 207 torch.testing.assert_close(result, result_compiled) 208 209 self.assertEqual(cnt.frame_count, 1, "only one compilation expected") 210 211 @requires_gpu 212 @parametrize("device", ["cpu", GPU_TYPE]) 213 @parametrize("dynamic", [False, True]) 214 def test_cond_simple_control_flow(self, device, dynamic): 215 # cond control flow without nesting 216 self._run_test( 217 model=CondModels.Simple(), 218 inputs=( 219 torch.randn(10, 20), 220 torch.randn(10, 20), 221 ), 222 device=device, 223 dynamic=dynamic, 224 ) 225 226 @requires_gpu 227 def test_cond_control_flow_with_precomputed_size(self): 228 class TestModel(torch.nn.Module): 229 def __init__( 230 self, 231 ): 232 super().__init__() 233 self.conv2d = torch.nn.Conv2d( 234 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) 235 ) 236 self.threshold = 20 237 238 def forward(self, x: torch.Tensor, index) -> torch.Tensor: 239 def true_fn(x: torch.Tensor): 240 return self.conv2d(x) 241 242 def false_fn(x: torch.Tensor): 243 return self.conv2d(x) 244 245 return torch.cond( 246 index < self.threshold and index >= 0, true_fn, false_fn, (x,) 247 ) 248 249 main_model = TestModel().to(GPU_TYPE) 250 x1 = torch.rand(2, 512, 128, 72).to(GPU_TYPE) 251 x2 = torch.rand(2, 512, 96, 96).to(GPU_TYPE) 252 253 opt_model = torch.compile(main_model) 254 out1 = main_model(x1, 1) 255 opt_out1 = opt_model(x1, 1) 256 self.assertTrue(torch.allclose(out1, opt_out1, atol=1e-5)) 257 258 out2 = main_model(x2, 30) 259 opt_out2 = opt_model(x2, 30) 260 self.assertTrue(torch.allclose(out2, opt_out2, atol=1e-5)) 261 262 @requires_gpu 263 @parametrize("device", ["cpu", GPU_TYPE]) 264 @parametrize("dynamic", [False, True]) 265 def test_cond_nested_control_flow(self, device, dynamic): 266 # cond control flow with nesting 267 self._run_test( 268 model=CondModels.Nested(), 269 inputs=( 270 torch.randn(10, 20), 271 torch.randn(10, 20), 272 torch.randn(10, 20), 273 ), 274 device=device, 275 dynamic=dynamic, 276 num_predicates=3, 277 ) 278 279 @requires_gpu 280 @parametrize("device", ["cpu", GPU_TYPE]) 281 @parametrize("dynamic", [False, True]) 282 def test_cond_outer_code_before_after(self, device, dynamic): 283 # some code before and after the conditional 284 self._run_test( 285 model=CondModels.OuterCode(), 286 inputs=( 287 torch.randn(10, 20), 288 torch.randn(10, 20), 289 ), 290 device=device, 291 dynamic=dynamic, 292 ) 293 294 @requires_gpu 295 @parametrize("device", ["cpu", GPU_TYPE]) 296 @parametrize("dynamic", [False, True]) 297 def test_cond_multiple_outputs(self, device, dynamic): 298 # multiple outputs with different shapes 299 self._run_test( 300 model=CondModels.MultipleOutputs(), 301 inputs=( 302 torch.randn(10, 20), 303 torch.randn(10, 20), 304 torch.randn(30, 40), 305 ), 306 device=device, 307 dynamic=dynamic, 308 ) 309 310 @requires_gpu 311 @parametrize("device", ["cpu", GPU_TYPE]) 312 def test_cond_advanced_dynamic_shapes(self, device): 313 # subgraphs input shapes include symbolic expressions 314 class Model(torch.nn.Module): 315 def forward(self, p, a, b): 316 def true_fn(x, y): 317 return torch.cat([x - 3, y * 3], dim=1) 318 319 def false_fn(x, y): 320 return torch.cat([x / 3, y - 3], dim=1) 321 322 c = torch.cat([a, b], dim=0) 323 d = c * 2 324 e = c / 2 325 326 return torch.cond(p, true_fn, false_fn, [d, e]) 327 328 self._run_test( 329 model=Model(), 330 inputs=( 331 torch.randn(2, 3, 3), 332 torch.randn(4, 3, 3), 333 ), 334 device=device, 335 dynamic=True, 336 ) 337 338 @requires_gpu 339 def test_cond_use_buffers_from_outer_scope(self): 340 # subgraphs input shapes include symbolic expressions 341 self._run_test( 342 model=CondModels.OuterBuffers(), 343 inputs=( 344 torch.randn(10, 20), 345 torch.randn(10, 20), 346 torch.randn(10, 20), 347 ), 348 device=GPU_TYPE, 349 dynamic=False, 350 ) 351 352 @requires_gpu 353 def test_cond_reintepret_view_inputs_outputs(self): 354 # ReinterpretView in inputs and outputs of the subgraphs 355 self._run_test( 356 model=CondModels.ReinterpretView(), 357 inputs=( 358 torch.randn(10, 20), 359 torch.randn(10, 20), 360 ), 361 device=GPU_TYPE, 362 dynamic=True, 363 ) 364 365 @requires_gpu 366 @parametrize("device", ["cpu", GPU_TYPE]) 367 @parametrize("dynamic", [False, True]) 368 def test_cond_subgraphs_with_parameters(self, device, dynamic): 369 # nested Modules with parameters 370 self._run_test( 371 model=CondModels.Parameters(device), 372 inputs=(torch.randn(10, 20),), 373 device=device, 374 dynamic=dynamic, 375 ) 376 377 @requires_gpu 378 @parametrize("device", ["cpu", GPU_TYPE]) 379 @parametrize("dynamic", [False, True]) 380 def test_cond_non_tensor_predicates(self, device, dynamic): 381 # model with a boolean predicate 382 for b_size_0 in [5, 15]: 383 torch._dynamo.reset() 384 self._run_test( 385 model=CondModels.WithNonTensorPredicate(), 386 inputs=( 387 torch.randn(10, 20), 388 torch.randn(b_size_0, 20), 389 ), 390 device=device, 391 dynamic=dynamic, 392 num_predicates=0, 393 ) 394 395 @requires_gpu 396 def test_cond_aliasing_outputs(self): 397 # output aliasing in subgraphs: not supported 398 class Model(torch.nn.Module): 399 def forward(self, p, a, b): 400 def true_fn(x, y): 401 z = x + y 402 return z, z[1:] 403 404 def false_fn(x, y): 405 z = x - y 406 return z, z[1:] 407 408 return torch.cond(p, true_fn, false_fn, [a, b]) 409 410 # AssertionError: Output aliasing is currently not supported... 411 with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed): 412 torch.compile(Model())( 413 torch.tensor(True), 414 torch.randn(10, 20), 415 torch.randn(10, 20), 416 ) 417 418 @requires_gpu 419 @parametrize("device", ["cpu", GPU_TYPE]) 420 def test_cond_decompose_ops_in_subgraph(self, device): 421 class Model(torch.nn.Module): 422 def forward(self, p, a): 423 def true_fn(x): 424 return torch.zeros_like(x) 425 426 def false_fn(x): 427 return torch.ones_like(x) 428 429 b = torch.ones_like(a) 430 c = torch.cond(p, true_fn, false_fn, [b]) 431 return c 432 433 self._run_test( 434 model=Model(), 435 inputs=(torch.rand(10, 20),), 436 device=device, 437 ) 438 439 @requires_gpu 440 @parametrize("device", ["cpu", GPU_TYPE]) 441 def test_cond_decompose_ops_in_subgraph_recursive(self, device): 442 def inner_fn1(x): 443 return torch.zeros_like(x) 444 445 def inner_fn2(x): 446 return torch.ones_like(x) 447 448 class Model(torch.nn.Module): 449 def forward(self, p, a): 450 def true_fn(x): 451 return torch.cond(p, inner_fn2, inner_fn1, [x]) 452 453 def false_fn(x): 454 return torch.cond(p, inner_fn1, inner_fn2, [x]) 455 456 b = torch.ones_like(a) 457 c = torch.cond(p, true_fn, false_fn, [b]) 458 return c 459 460 self._run_test( 461 model=Model(), 462 inputs=(torch.rand(10, 20),), 463 device=device, 464 ) 465 466 @requires_gpu 467 def test_cond_inductor_fx_passes_recursively_applied(self): 468 counters = {"pre_grad": 0, "post_grad": 0} 469 470 def pre_grad_pass_counter(gm): 471 counters["pre_grad"] += 1 472 473 def post_grad_pass_counter(gm): 474 counters["post_grad"] += 1 475 476 with torch._inductor.config.patch( 477 { 478 "pre_grad_custom_pass": pre_grad_pass_counter, 479 "post_grad_custom_pre_pass": post_grad_pass_counter, 480 # The above patches don't pickle 481 "fx_graph_cache": False, 482 } 483 ): 484 self._run_test( 485 model=CondModels.Nested(), 486 inputs=( 487 torch.randn(10, 20), 488 torch.randn(10, 20), 489 torch.randn(10, 20), 490 ), 491 device=GPU_TYPE, 492 dynamic=True, 493 num_predicates=3, 494 ) 495 496 self.assertEqual(counters["pre_grad"], 11) 497 self.assertEqual(counters["post_grad"], 11) 498 499 500class WhileLoopModels: 501 class Simple(torch.nn.Module): 502 def forward(self, ci, a, b): 503 def cond_fn(i, x, y): 504 return i > 0 505 506 def body_fn(i, x, y): 507 return i - 1, x + y, y - x 508 509 return torch._higher_order_ops.while_loop(cond_fn, body_fn, [ci, a, b]) 510 511 class Nested(torch.nn.Module): 512 def forward(self, ci, cj, a, b): 513 def cond_fn(i1, j1, x1, y1): 514 return i1 > 0 515 516 def body_fn(i1, j1, x1, y1): 517 def cond_fn_nested(i2, j2, x2, y2): 518 return j2 > 0 519 520 def body_fn_nested(i2, j2, x2, y2): 521 return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71 522 523 i1, j1, x1, y1 = torch._higher_order_ops.while_loop( 524 cond_fn_nested, body_fn_nested, [i1, j1, x1, y1] 525 ) 526 527 return i1 - 1, j1.clone(), x1 * 2, y1 / 2 528 529 return torch._higher_order_ops.while_loop(cond_fn, body_fn, (ci, cj, a, b)) 530 531 class Parameters(torch.nn.Module): 532 class InnerModel(torch.nn.Module): 533 def __init__(self, device): 534 super().__init__() 535 self.layer1 = torch.nn.Linear(20, 30, device=device) 536 self.layer2 = torch.nn.Linear(30, 20, device=device) 537 538 def forward(self, c, x): 539 return c - 1, self.layer2(self.layer1(x - 2)) * 3.14 540 541 def __init__(self, device): 542 super().__init__() 543 self.body_fn = self.InnerModel(device) 544 self.cond_fn = lambda c, x: c > 0 545 546 def forward(self, c, a): 547 return torch._higher_order_ops.while_loop( 548 self.cond_fn, self.body_fn, [c, a] 549 ) 550 551 class OuterCode(torch.nn.Module): 552 def forward(self, c, a, b): 553 d = a * b + 3.14 554 e = a / b - 2.71 555 556 def cond_fn(c, x, y): 557 return c > 0 558 559 def body_fn(c, x, y): 560 return c - 1, y - x, x + y 561 562 _, f, g = torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, d, e]) 563 564 return f * g / 1.41 565 566 # TODO(aakhundov): add while_loop test with outer buffers 567 # with dynamic=True once dynamo / export allows while_loop 568 # closure capture with mark_dynamic: 569 # https://github.com/pytorch/pytorch/issues/123596 570 class OuterBuffers(torch.nn.Module): 571 def forward(self, c, a, b): 572 d = a * 2 573 e = b / 2 574 575 def cond_fn(c, x, y): 576 return c > 0 577 578 def body_fn(c, x, y): 579 return c - 1, x + d, y - e 580 581 return torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, a, b]) 582 583 584class WhileLoopTests(TestCase): 585 def _run_test( 586 self, 587 model, 588 inputs, 589 device, 590 dynamic=False, 591 num_counters=1, 592 ): 593 cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor") 594 compiled_model = torch.compile(backend=cnt, fullgraph=True)(model) 595 596 inputs = [inp.to(device=device) for inp in inputs] 597 input_sets = [inputs] 598 if dynamic: 599 larger_inputs = [] 600 for inp in inputs: 601 # tile every first dim 5x 602 tiling = [5] + [1] * (inp.ndim - 1) 603 larger_inputs.append(torch.tile(inp, tiling)) 604 input_sets.append(larger_inputs) 605 for inputs in input_sets: 606 for inp in inputs: 607 # mark every first dim as dynamic 608 if inp.ndim: 609 torch._dynamo.mark_dynamic(inp, 0) 610 611 for inputs in input_sets: 612 for inputs_with_counters in prepend_counters(inputs, num_counters): 613 cloned_inputs = [inp.clone() for inp in inputs_with_counters] 614 result = model(*inputs_with_counters) 615 with torch.no_grad(): 616 result_compiled = compiled_model(*inputs_with_counters) 617 # inputs must not be mutated 618 torch.testing.assert_close(cloned_inputs, inputs_with_counters) 619 torch.testing.assert_close( 620 result, result_compiled, atol=1e-4, rtol=1e-4 621 ) 622 623 self.assertEqual(cnt.frame_count, 1, "only one compilation expected") 624 625 @requires_gpu 626 @parametrize("device", ["cpu", GPU_TYPE]) 627 @parametrize("dynamic", [False, True]) 628 def test_while_loop_simple_control_flow(self, device, dynamic): 629 # while_loop control flow without nesting 630 self._run_test( 631 model=WhileLoopModels.Simple(), 632 inputs=( 633 torch.randn(10, 20), 634 torch.randn(10, 20), 635 ), 636 device=device, 637 dynamic=dynamic, 638 ) 639 640 @requires_gpu 641 @parametrize("device", ["cpu", GPU_TYPE]) 642 @parametrize("dynamic", [False, True]) 643 def test_while_loop_nested_control_flow(self, device, dynamic): 644 # while_loop control flow with nesting 645 self._run_test( 646 model=WhileLoopModels.Nested(), 647 inputs=( 648 torch.randn(10, 20), 649 torch.randn(10, 20), 650 ), 651 device=device, 652 dynamic=dynamic, 653 num_counters=2, 654 ) 655 656 @requires_gpu 657 @parametrize("device", ["cpu", GPU_TYPE]) 658 @parametrize("dynamic", [False, True]) 659 def test_while_loop_with_outer_code(self, device, dynamic): 660 # while_loop control flow with outer code 661 self._run_test( 662 model=WhileLoopModels.OuterCode(), 663 inputs=( 664 torch.randn(10, 20), 665 torch.randn(10, 20), 666 ), 667 device=device, 668 dynamic=dynamic, 669 ) 670 671 @requires_gpu 672 @parametrize("device", ["cpu", GPU_TYPE]) 673 @parametrize("dynamic", [False, True]) 674 def test_while_loop_with_parameters(self, device, dynamic): 675 # while_loop control flow with parameters 676 self._run_test( 677 model=WhileLoopModels.Parameters(device), 678 inputs=(torch.randn(10, 20),), 679 device=device, 680 dynamic=dynamic, 681 ) 682 683 @requires_gpu 684 @parametrize("device", ["cpu", GPU_TYPE]) 685 # dynamic=True doesn't work now due to 686 # https://github.com/pytorch/pytorch/issues/123596 687 @parametrize("dynamic", [False]) 688 def test_while_loop_with_outer_buffers(self, device, dynamic): 689 # while_loop control flow with outer code 690 self._run_test( 691 model=WhileLoopModels.OuterBuffers(), 692 inputs=( 693 torch.randn(10, 20), 694 torch.randn(10, 20), 695 ), 696 device=device, 697 dynamic=dynamic, 698 ) 699 700 701class AssociativeScanTests(TestCase): 702 @requires_gpu 703 @parametrize("combine_mode", ["pointwise", "generic"]) 704 @parametrize("backend", ["inductor"]) 705 @parametrize("device", [torch.device("cpu"), GPU_TYPE]) 706 # This test will fail as flip in combination with particular input lenghts 707 # produces weird results. 708 # This is under investigations in 709 # https://github.com/pytorch/pytorch/issues/131805 710 @decorateIf(unittest.skip, lambda params: params["device"] == GPU_TYPE) 711 def test_associative_scan_CUDA_flip(self, combine_mode, backend, device): 712 def fct(x: torch.Tensor, y: torch.Tensor): 713 return x + y 714 715 for n in range(10): 716 x = torch.arange(n, device=device) 717 torch.compiler.reset() 718 associative_scan1 = torch.compile( 719 associative_scan, backend=backend, fullgraph=True 720 ) 721 associative_scan2 = associative_scan 722 723 if combine_mode == "pointwise" and device == torch.device("cpu"): 724 with self.assertRaisesRegex(Exception, r"."): 725 associative_scan1( 726 fct, x, 0, reverse=False, combine_mode=combine_mode 727 ) 728 729 # Skipping test because combine_mode currently only suppors CUDA tensors 730 return 731 732 result1 = associative_scan1( 733 fct, x, 0, reverse=False, combine_mode=combine_mode 734 ) 735 result2 = associative_scan2( 736 fct, x, 0, reverse=False, combine_mode=combine_mode 737 ) 738 result3 = torch.cumsum(x, 0) 739 740 self.assertEqual(result1, result2) 741 self.assertEqual(result1, result3) 742 743 # Flip only non-compiled and compare with compiled reverse=True 744 result1 = associative_scan1( 745 fct, x, 0, reverse=True, combine_mode=combine_mode 746 ) 747 result2 = torch.flip( 748 associative_scan2( 749 fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode 750 ), 751 [0], 752 ) 753 result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) 754 755 self.assertEqual(result1, result2) 756 self.assertEqual(result1, result3) 757 758 # Flip only compiled and compare with non-compiled reverse=True 759 result1 = torch.flip( 760 associative_scan1( 761 fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode 762 ), 763 [0], 764 ) 765 result2 = associative_scan2( 766 fct, x, 0, reverse=True, combine_mode=combine_mode 767 ) 768 result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) 769 770 self.assertEqual(result1, result2) 771 self.assertEqual(result1, result3) 772 773 # Use reverse=False, but flip both results before and after 774 result1 = torch.flip( 775 associative_scan1( 776 fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode 777 ), 778 [0], 779 ) 780 result2 = torch.flip( 781 associative_scan2( 782 fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode 783 ), 784 [0], 785 ) 786 result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) 787 788 self.assertEqual(result1, result2) 789 self.assertEqual(result1, result3) 790 791 # Reverse=True 792 result1 = associative_scan1( 793 fct, x, 0, reverse=True, combine_mode=combine_mode 794 ) 795 result2 = associative_scan2( 796 fct, x, 0, reverse=True, combine_mode=combine_mode 797 ) 798 result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]) 799 800 self.assertEqual(result1, result2) 801 self.assertEqual(result1, result3) 802 803 804instantiate_parametrized_tests(CondTests) 805instantiate_parametrized_tests(WhileLoopTests) 806instantiate_parametrized_tests(AssociativeScanTests) 807 808 809if __name__ == "__main__": 810 from torch._inductor.test_case import run_tests 811 812 if HAS_CPU or HAS_GPU: 813 run_tests(needs="filelock") 814