# Owner(s): ["module: inductor"]
import itertools
import unittest

import torch
import torch._dynamo.testing
from torch._higher_order_ops.associative_scan import associative_scan
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import (
    decorateIf,
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
from torch.testing._internal.triton_utils import requires_gpu


def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
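    # Build one variant of `inputs` per combination of `possible_values`: each value
    # in the combination is turned into a scalar tensor on the inputs' device and
    # prepended to the original inputs.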
    result = []
    device = inputs[0].device
    # iterate over the cartesian product of possible values
    for values in itertools.product(*([possible_values] * num_to_prepend)):
        prepended = [torch.tensor(v, device=device) for v in values]
        result.append((*prepended, *inputs))
    return result


def prepend_predicates(inputs, num_predicates=1):
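    # e.g. num_predicates=2 yields four variants of `inputs`, prefixed by
    # (tensor(False), tensor(False)), (tensor(False), tensor(True)),
    # (tensor(True), tensor(False)) and (tensor(True), tensor(True))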
    return _prepend_product_of_values(inputs, [False, True], num_predicates)


def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)):
    return _prepend_product_of_values(inputs, counter_values, num_counters)


class CondModels:
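    # Small models exercising torch.cond; each is compiled with inductor and
    # checked against eager execution by CondTests._run_test below.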
    class Simple(torch.nn.Module):
        def forward(self, p, a, b):
            def true_fn(x, y):
                return x + y

            def false_fn(x, y):
                return x - y

            return torch.cond(p, true_fn, false_fn, [a, b])

    class Nested(torch.nn.Module):
        def forward(self, p0, p1, p2, a, b, c):
            def true_fn(x0, y0, z0):
                def true_true_fn(x1, y1, z1):
                    return (x1 - y1 * z1) * 3.14

                def true_false_fn(x1, y1, z1):
                    def true_false_true_fn(x2, y2, z2):
                        return (x2 * y2 * z2) / 2.71

                    def true_false_false_fn(x2, y2, z2):
                        return (x2 + y2 + z2) * 1.23

                    return torch.cond(
                        p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
                    )

                return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])

            def false_fn(x0, y0, z0):
                def false_true_fn(x1, y1, z1):
                    def false_true_true_fn(x2, y2, z2):
                        return (x2 - y2 - z2) + 1.23

                    def false_true_false_fn(x2, y2, z2):
                        return (x2 / y2 / z2) - 3.14

                    return torch.cond(
                        p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
                    )

                def false_false_fn(x1, y1, z1):
                    return (x1 - y1 * z1) / 2.71

                return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])

            return torch.cond(p0, true_fn, false_fn, [a, b, c])

    class Parameters(torch.nn.Module):
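        # the branches passed to torch.cond are nn.Modules here; the two inner
        # models below carry their own Linear parameters, exercising parameters
        # used inside the subgraphs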
        class InnerModel1(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.layer = torch.nn.Linear(20, 30, device=device)

            def forward(self, x):
                return self.layer(x + 1) * 3.14

        class InnerModel2(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.layer1 = torch.nn.Linear(20, 10, device=device)
                self.layer2 = torch.nn.Linear(10, 30, device=device)

            def forward(self, x):
                return self.layer2(self.layer1(x - 2)) * 3.14

        def __init__(self, device):
            super().__init__()
            self.true_fn = self.InnerModel1(device)
            self.false_fn = self.InnerModel2(device)

        def forward(self, p, a):
            return torch.cond(p, self.true_fn, self.false_fn, [a])

    class ReinterpretView(torch.nn.Module):
        def forward(self, p, a, b):
            def true_fn(x, y):
                z1 = x + y
                z2 = x - y
                return z1[2:], z2[:, 4:]

            def false_fn(x, y):
                z1 = x - y
                z2 = x + y
                return z1[2:], z2[:, 4:]

            return torch.cond(p, true_fn, false_fn, [a[:-1], b[:-1]])

    class MultipleOutputs(torch.nn.Module):
        def forward(self, p, a, b, c):
            def true_fn(x, y, z):
                return x * y, z / 2.71, (y - x).sum(dim=1)

            def false_fn(x, y, z):
                return y / x, z * 3.14, (x + y).mean(dim=1)

            return torch.cond(p, true_fn, false_fn, [a, b, c])

    class OuterCode(torch.nn.Module):
        def forward(self, p, a, b):
            c = a * b + 3.14
            d = a / b - 2.71

            def true_fn(x, y):
                return x + y

            def false_fn(x, y):
                return x - y

            e = torch.cond(p, true_fn, false_fn, [c, d])

            return e * e / 1.41

    class OuterBuffers(torch.nn.Module):
        def forward(self, p, a, b, c):
            d = a * 2
            e = b / 2

            def true_fn(x):
                return x + d

            def false_fn(x):
                return x - e

            return torch.cond(p, true_fn, false_fn, [c])

    class WithNonTensorPredicate(torch.nn.Module):
        def forward(self, a, b):
            def true_fn(x, y):
                return x.sum(0) / 3.14

            def false_fn(x, y):
                return y.sum(0) * 2.71

            return torch.cond(a.size(0) > b.size(0), true_fn, false_fn, [a, b])


class CondTests(TestCase):
    def _run_test(
        self,
        model,
        inputs,
        device,
        dynamic=False,
        num_predicates=1,
    ):
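        # Compile the model with the inductor backend (fullgraph=True), run the
        # eager and compiled versions over every combination of predicate values
        # (and over 5x-tiled inputs when dynamic=True), then check that the outputs
        # match, the inputs are not mutated, and only one compilation occurred.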
        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)

        inputs = [inp.to(device=device) for inp in inputs]
        input_sets = [inputs]
        if dynamic:
            larger_inputs = []
            for inp in inputs:
                # tile the first dim of each input 5x
                tiling = [5] + [1] * (inp.ndim - 1)
                larger_inputs.append(torch.tile(inp, tiling))
            input_sets.append(larger_inputs)
            for inputs in input_sets:
                for inp in inputs:
                    # mark the first dim of each input as dynamic
                    torch._dynamo.mark_dynamic(inp, 0)

        for inputs in input_sets:
            for inputs_with_predicates in prepend_predicates(inputs, num_predicates):
                cloned_inputs = [inp.clone() for inp in inputs_with_predicates]
                result = model(*inputs_with_predicates)
                result_compiled = compiled_model(*inputs_with_predicates)
                # inputs must not be mutated
                torch.testing.assert_close(cloned_inputs, inputs_with_predicates)
                torch.testing.assert_close(result, result_compiled)

        self.assertEqual(cnt.frame_count, 1, "only one compilation expected")

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_simple_control_flow(self, device, dynamic):
        # cond control flow without nesting
        self._run_test(
            model=CondModels.Simple(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    def test_cond_control_flow_with_precomputed_size(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(
                    512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
                )
                self.threshold = 20

            def forward(self, x: torch.Tensor, index) -> torch.Tensor:
                def true_fn(x: torch.Tensor):
                    return self.conv2d(x)

                def false_fn(x: torch.Tensor):
                    return self.conv2d(x)

                return torch.cond(
                    index < self.threshold and index >= 0, true_fn, false_fn, (x,)
                )

        main_model = TestModel().to(GPU_TYPE)
        x1 = torch.rand(2, 512, 128, 72).to(GPU_TYPE)
        x2 = torch.rand(2, 512, 96, 96).to(GPU_TYPE)

        opt_model = torch.compile(main_model)
        out1 = main_model(x1, 1)
        opt_out1 = opt_model(x1, 1)
        self.assertTrue(torch.allclose(out1, opt_out1, atol=1e-5))

        out2 = main_model(x2, 30)
        opt_out2 = opt_model(x2, 30)
        self.assertTrue(torch.allclose(out2, opt_out2, atol=1e-5))

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_nested_control_flow(self, device, dynamic):
        # cond control flow with nesting
        self._run_test(
            model=CondModels.Nested(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
            num_predicates=3,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_outer_code_before_after(self, device, dynamic):
        # some code before and after the conditional
        self._run_test(
            model=CondModels.OuterCode(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_multiple_outputs(self, device, dynamic):
        # multiple outputs with different shapes
        self._run_test(
            model=CondModels.MultipleOutputs(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(30, 40),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    def test_cond_advanced_dynamic_shapes(self, device):
        # subgraphs input shapes include symbolic expressions
        class Model(torch.nn.Module):
            def forward(self, p, a, b):
                def true_fn(x, y):
                    return torch.cat([x - 3, y * 3], dim=1)

                def false_fn(x, y):
                    return torch.cat([x / 3, y - 3], dim=1)

                c = torch.cat([a, b], dim=0)
                d = c * 2
                e = c / 2

                return torch.cond(p, true_fn, false_fn, [d, e])

        self._run_test(
            model=Model(),
            inputs=(
                torch.randn(2, 3, 3),
                torch.randn(4, 3, 3),
            ),
            device=device,
            dynamic=True,
        )

    @requires_gpu
    def test_cond_use_buffers_from_outer_scope(self):
        # buffers created in the outer scope are used inside the subgraphs
        self._run_test(
            model=CondModels.OuterBuffers(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=GPU_TYPE,
            dynamic=False,
        )

    @requires_gpu
    def test_cond_reintepret_view_inputs_outputs(self):
        # ReinterpretView in inputs and outputs of the subgraphs
        self._run_test(
            model=CondModels.ReinterpretView(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=GPU_TYPE,
            dynamic=True,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_subgraphs_with_parameters(self, device, dynamic):
        # nested Modules with parameters
        self._run_test(
            model=CondModels.Parameters(device),
            inputs=(torch.randn(10, 20),),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_cond_non_tensor_predicates(self, device, dynamic):
        # model with a non-tensor (shape-derived boolean) predicate
        for b_size_0 in [5, 15]:
            torch._dynamo.reset()
            self._run_test(
                model=CondModels.WithNonTensorPredicate(),
                inputs=(
                    torch.randn(10, 20),
                    torch.randn(b_size_0, 20),
                ),
                device=device,
                dynamic=dynamic,
                num_predicates=0,
            )

    @requires_gpu
    def test_cond_aliasing_outputs(self):
        # output aliasing in subgraphs: not supported
        class Model(torch.nn.Module):
            def forward(self, p, a, b):
                def true_fn(x, y):
                    z = x + y
                    return z, z[1:]

                def false_fn(x, y):
                    z = x - y
                    return z, z[1:]

                return torch.cond(p, true_fn, false_fn, [a, b])

        # AssertionError: Output aliasing is currently not supported...
        with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed):
            torch.compile(Model())(
                torch.tensor(True),
                torch.randn(10, 20),
                torch.randn(10, 20),
            )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    def test_cond_decompose_ops_in_subgraph(self, device):
        class Model(torch.nn.Module):
            def forward(self, p, a):
                def true_fn(x):
                    return torch.zeros_like(x)

                def false_fn(x):
                    return torch.ones_like(x)

                b = torch.ones_like(a)
                c = torch.cond(p, true_fn, false_fn, [b])
                return c

        self._run_test(
            model=Model(),
            inputs=(torch.rand(10, 20),),
            device=device,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    def test_cond_decompose_ops_in_subgraph_recursive(self, device):
        def inner_fn1(x):
            return torch.zeros_like(x)

        def inner_fn2(x):
            return torch.ones_like(x)

        class Model(torch.nn.Module):
            def forward(self, p, a):
                def true_fn(x):
                    return torch.cond(p, inner_fn2, inner_fn1, [x])

                def false_fn(x):
                    return torch.cond(p, inner_fn1, inner_fn2, [x])

                b = torch.ones_like(a)
                c = torch.cond(p, true_fn, false_fn, [b])
                return c

        self._run_test(
            model=Model(),
            inputs=(torch.rand(10, 20),),
            device=device,
        )

    @requires_gpu
    def test_cond_inductor_fx_passes_recursively_applied(self):
        counters = {"pre_grad": 0, "post_grad": 0}

        def pre_grad_pass_counter(gm):
            counters["pre_grad"] += 1

        def post_grad_pass_counter(gm):
            counters["post_grad"] += 1

        with torch._inductor.config.patch(
            {
                "pre_grad_custom_pass": pre_grad_pass_counter,
                "post_grad_custom_pre_pass": post_grad_pass_counter,
                # The above patches don't pickle
                "fx_graph_cache": False,
            }
        ):
            self._run_test(
                model=CondModels.Nested(),
                inputs=(
                    torch.randn(10, 20),
                    torch.randn(10, 20),
                    torch.randn(10, 20),
                ),
                device=GPU_TYPE,
                dynamic=True,
                num_predicates=3,
            )

        self.assertEqual(counters["pre_grad"], 11)
        self.assertEqual(counters["post_grad"], 11)


class WhileLoopModels:
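    # Small models exercising torch._higher_order_ops.while_loop; each is compiled
    # with inductor and checked against eager execution by WhileLoopTests._run_test.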
    class Simple(torch.nn.Module):
        def forward(self, ci, a, b):
            def cond_fn(i, x, y):
                return i > 0

            def body_fn(i, x, y):
                return i - 1, x + y, y - x

            return torch._higher_order_ops.while_loop(cond_fn, body_fn, [ci, a, b])

    class Nested(torch.nn.Module):
        def forward(self, ci, cj, a, b):
            def cond_fn(i1, j1, x1, y1):
                return i1 > 0

            def body_fn(i1, j1, x1, y1):
                def cond_fn_nested(i2, j2, x2, y2):
                    return j2 > 0

                def body_fn_nested(i2, j2, x2, y2):
                    return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71

                i1, j1, x1, y1 = torch._higher_order_ops.while_loop(
                    cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
                )

                return i1 - 1, j1.clone(), x1 * 2, y1 / 2

            return torch._higher_order_ops.while_loop(cond_fn, body_fn, (ci, cj, a, b))

    class Parameters(torch.nn.Module):
        class InnerModel(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.layer1 = torch.nn.Linear(20, 30, device=device)
                self.layer2 = torch.nn.Linear(30, 20, device=device)

            def forward(self, c, x):
                return c - 1, self.layer2(self.layer1(x - 2)) * 3.14

        def __init__(self, device):
            super().__init__()
            self.body_fn = self.InnerModel(device)
            self.cond_fn = lambda c, x: c > 0

        def forward(self, c, a):
            return torch._higher_order_ops.while_loop(
                self.cond_fn, self.body_fn, [c, a]
            )

    class OuterCode(torch.nn.Module):
        def forward(self, c, a, b):
            d = a * b + 3.14
            e = a / b - 2.71

            def cond_fn(c, x, y):
                return c > 0

            def body_fn(c, x, y):
                return c - 1, y - x, x + y

            _, f, g = torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, d, e])

            return f * g / 1.41

    # TODO(aakhundov): add while_loop test with outer buffers
    # with dynamic=True once dynamo / export allows while_loop
    # closure capture with mark_dynamic:
    # https://github.com/pytorch/pytorch/issues/123596
    class OuterBuffers(torch.nn.Module):
        def forward(self, c, a, b):
            d = a * 2
            e = b / 2

            def cond_fn(c, x, y):
                return c > 0

            def body_fn(c, x, y):
                return c - 1, x + d, y - e

            return torch._higher_order_ops.while_loop(cond_fn, body_fn, [c, a, b])


class WhileLoopTests(TestCase):
    def _run_test(
        self,
        model,
        inputs,
        device,
        dynamic=False,
        num_counters=1,
    ):
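        # Same harness as CondTests._run_test, except that the prepended scalars are
        # loop counters (0, 1, 5 by default) rather than boolean predicates, and the
        # results are compared with a slightly looser tolerance.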
        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)

        inputs = [inp.to(device=device) for inp in inputs]
        input_sets = [inputs]
        if dynamic:
            larger_inputs = []
            for inp in inputs:
                # tile the first dim of each input 5x
                tiling = [5] + [1] * (inp.ndim - 1)
                larger_inputs.append(torch.tile(inp, tiling))
            input_sets.append(larger_inputs)
            for inputs in input_sets:
                for inp in inputs:
                    # mark the first dim of each input as dynamic
                    if inp.ndim:
                        torch._dynamo.mark_dynamic(inp, 0)

        for inputs in input_sets:
            for inputs_with_counters in prepend_counters(inputs, num_counters):
                cloned_inputs = [inp.clone() for inp in inputs_with_counters]
                result = model(*inputs_with_counters)
                with torch.no_grad():
                    result_compiled = compiled_model(*inputs_with_counters)
                # inputs must not be mutated
                torch.testing.assert_close(cloned_inputs, inputs_with_counters)
                torch.testing.assert_close(
                    result, result_compiled, atol=1e-4, rtol=1e-4
                )

        self.assertEqual(cnt.frame_count, 1, "only one compilation expected")

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_while_loop_simple_control_flow(self, device, dynamic):
        # while_loop control flow without nesting
        self._run_test(
            model=WhileLoopModels.Simple(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_while_loop_nested_control_flow(self, device, dynamic):
        # while_loop control flow with nesting
        self._run_test(
            model=WhileLoopModels.Nested(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
            num_counters=2,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_while_loop_with_outer_code(self, device, dynamic):
        # while_loop control flow with outer code
        self._run_test(
            model=WhileLoopModels.OuterCode(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    @parametrize("dynamic", [False, True])
    def test_while_loop_with_parameters(self, device, dynamic):
        # while_loop control flow with parameters
        self._run_test(
            model=WhileLoopModels.Parameters(device),
            inputs=(torch.randn(10, 20),),
            device=device,
            dynamic=dynamic,
        )

    @requires_gpu
    @parametrize("device", ["cpu", GPU_TYPE])
    # dynamic=True doesn't work now due to
    # https://github.com/pytorch/pytorch/issues/123596
    @parametrize("dynamic", [False])
    def test_while_loop_with_outer_buffers(self, device, dynamic):
        # while_loop control flow with buffers from the outer scope
        self._run_test(
            model=WhileLoopModels.OuterBuffers(),
            inputs=(
                torch.randn(10, 20),
                torch.randn(10, 20),
            ),
            device=device,
            dynamic=dynamic,
        )


class AssociativeScanTests(TestCase):
    @requires_gpu
    @parametrize("combine_mode", ["pointwise", "generic"])
    @parametrize("backend", ["inductor"])
    @parametrize("device", [torch.device("cpu"), GPU_TYPE])
    # This test will fail, as flip in combination with particular input lengths
    # produces incorrect results.
    # This is under investigation in
    # https://github.com/pytorch/pytorch/issues/131805
    @decorateIf(unittest.skip, lambda params: params["device"] == GPU_TYPE)
    def test_associative_scan_CUDA_flip(self, combine_mode, backend, device):
        def fct(x: torch.Tensor, y: torch.Tensor):
            return x + y
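        # with addition as the combine function, associative_scan is equivalent to
        # a cumulative sum, so torch.cumsum serves as the reference below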

        for n in range(10):
            x = torch.arange(n, device=device)
            torch.compiler.reset()
            associative_scan1 = torch.compile(
                associative_scan, backend=backend, fullgraph=True
            )
            associative_scan2 = associative_scan

            if combine_mode == "pointwise" and device == torch.device("cpu"):
                with self.assertRaisesRegex(Exception, r"."):
                    associative_scan1(
                        fct, x, 0, reverse=False, combine_mode=combine_mode
                    )

                # Skipping test because combine_mode currently only supports CUDA tensors
                return

            result1 = associative_scan1(
                fct, x, 0, reverse=False, combine_mode=combine_mode
            )
            result2 = associative_scan2(
                fct, x, 0, reverse=False, combine_mode=combine_mode
            )
            result3 = torch.cumsum(x, 0)

            self.assertEqual(result1, result2)
            self.assertEqual(result1, result3)

            # Flip only non-compiled and compare with compiled reverse=True
            result1 = associative_scan1(
                fct, x, 0, reverse=True, combine_mode=combine_mode
            )
            result2 = torch.flip(
                associative_scan2(
                    fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
                ),
                [0],
            )
            result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])

            self.assertEqual(result1, result2)
            self.assertEqual(result1, result3)

            # Flip only compiled and compare with non-compiled reverse=True
            result1 = torch.flip(
                associative_scan1(
                    fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
                ),
                [0],
            )
            result2 = associative_scan2(
                fct, x, 0, reverse=True, combine_mode=combine_mode
            )
            result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])

            self.assertEqual(result1, result2)
            self.assertEqual(result1, result3)

            # Use reverse=False, but flip both results before and after
            result1 = torch.flip(
                associative_scan1(
                    fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
                ),
                [0],
            )
            result2 = torch.flip(
                associative_scan2(
                    fct, torch.flip(x, [0]), 0, reverse=False, combine_mode=combine_mode
                ),
                [0],
            )
            result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])

            self.assertEqual(result1, result2)
            self.assertEqual(result1, result3)

            # Reverse=True
            result1 = associative_scan1(
                fct, x, 0, reverse=True, combine_mode=combine_mode
            )
            result2 = associative_scan2(
                fct, x, 0, reverse=True, combine_mode=combine_mode
            )
            result3 = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])

            self.assertEqual(result1, result2)
            self.assertEqual(result1, result3)


instantiate_parametrized_tests(CondTests)
instantiate_parametrized_tests(WhileLoopTests)
instantiate_parametrized_tests(AssociativeScanTests)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_CPU or HAS_GPU:
        run_tests(needs="filelock")