# Owner(s): ["module: functorch"]
import contextlib
import functools
import unittest

import torch
import torch.utils._pytree as pytree
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
from torch._higher_order_ops.associative_scan import associative_scan
from torch._higher_order_ops.while_loop import while_loop
from torch._subclasses.functional_tensor import (
    CppFunctionalizeAPI,
    FunctionalTensor,
    FunctionalTensorMode,
    PythonFunctionalizeAPI,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM70OrLater
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
    decorateIf,
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfRocm,
    skipIfTorchDynamo,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    xfailIfTorchDynamo,
)


# TODO: pull these helpers from AOTAutograd later
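# The helpers below wrap and unwrap tensors for functionalization: to_fun/from_fun
# work with the FunctionalTensor subclass, while to_fun_old/from_fun_old use the
# older torch._to_functional_tensor/_from_functional_tensor path.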
def to_fun(t):
    if isinstance(t, torch.Tensor):
        return FunctionalTensor.to_functional(t)
    return t


def from_fun(t):
    if not isinstance(t, FunctionalTensor):
        # quick sanity assert
        if isinstance(t, torch.Tensor):
            assert not torch._is_functional_tensor(t)
        return t
    torch._sync(t)
    return torch._from_functional_tensor(t.elem)


def to_fun_old(t):
    if isinstance(t, torch.Tensor) and not torch._is_functional_tensor(t):
        out = torch._to_functional_tensor(t)
        torch._mirror_autograd_meta_to(t, out)
        return out
    return t


def from_fun_old(t):
    # quick sanity assert
    if isinstance(t, torch.Tensor):
        assert torch._is_functional_tensor(t)
        torch._sync(t)
        return torch._from_functional_tensor(t)
    return t


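# The _fake_* helpers below are plain-Python reference implementations of the
# map, while_loop, and associative_scan control-flow ops; tests compare the real
# ops against them to check numerics.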
def _fake_map(f, x, *args):
    from functorch.experimental.control_flow import _stack_pytree, _unstack_pytree

    x_pytrees = _unstack_pytree(x)
    zs = []
    for xp in x_pytrees:
        zs.append(f(xp, *args))
    return _stack_pytree(zs)


def _fake_while_loop(cond_fn, body_fn, operands):
    while cond_fn(*operands):
        operands = body_fn(*operands)
    return operands


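# _fake_associative_scan runs a sequential inclusive scan with combine_fn along
# `dim` over a pytree of tensors, scanning from the back when reverse=True.
# For example, with an add combine_fn on torch.tensor([1.0, 2.0, 3.0]) and dim=0
# it yields [1.0, 3.0, 6.0], and [6.0, 5.0, 3.0] when reverse=True.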
def _fake_associative_scan(combine_fn, input, dim, reverse=False):
    inp_leaves, spec = pytree.tree_flatten(input)
    result_flat = []
    num_leaves = len(inp_leaves)
    op = reversed if reverse else lambda x: x

    for ind in op(range(inp_leaves[0].size(dim))):
        r = [
            inp_leaves[leave_ind][(slice(None),) * dim + (ind,)]
            for leave_ind in range(num_leaves)
        ]
        if (ind > 0 and not reverse) or (
            ind < (inp_leaves[0].size(dim) - 1) and reverse
        ):
            r = combine_fn(
                pytree.tree_unflatten(result_flat[-1], spec),
                pytree.tree_unflatten(r, spec),
            )
        r_flat, _ = pytree.tree_flatten(r)
        result_flat.append(r_flat)

    results = [
        torch.stack([e[leave_ind] for e in op(result_flat)], dim)
        for leave_ind in range(num_leaves)
    ]
    return pytree.tree_unflatten(results, spec)


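# _while_loop_tests builds a name -> (callable or module, example inputs) mapping
# of while_loop test cases (simple, nested, with mutation, and with linear
# layers/buffers); it is materialized once below as WHILE_LOOP_TESTS.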
def _while_loop_tests():
    def simple(x):
        def cond_fn(x):
            return x.sum() < 10

        def body_fn(x):
            return (x + 1,)

        return while_loop(cond_fn, body_fn, (x,))

    def simple_with_mutation(x):
        def cond_fn(x):
            y = x.clone().add_(1).add_(-1)
            return y.sum() < 10

        def body_fn(x):
            y = x.clone().add_(1).add_(-1)
            return (y + 1,)

        return while_loop(cond_fn, body_fn, (x,))

    def nested(out_iter, it, y):
        def cond_fn(out_iter, it, y):
            return it.sum() < 10

        def body_fn(out_iter, it, y):
            return (out_iter.clone(), it + y, y + 1)

        def outer_cond_fn(out_iter, it, y):
            return out_iter.sum() < 2

        def outer_body_fn(out_iter, it, y):
            out_iter, it, y = while_loop(cond_fn, body_fn, (out_iter, it, y))
            return (out_iter + 1, it, y)

        return while_loop(outer_cond_fn, outer_body_fn, (out_iter, it, y))

    class Nested(torch.nn.Module):
        def forward(self, ci, cj, a, b):
            def cond_fn(i1, j1, x1, y1):
                return i1 > 0

            def body_fn(i1, j1, x1, y1):
                def cond_fn_nested(i2, j2, x2, y2):
                    return j2 > 0

                def body_fn_nested(i2, j2, x2, y2):
                    return i2.clone(), j2 - 1, x2 + 3.14, y2 - 2.71

                i1, j1, x1, y1 = while_loop(
                    cond_fn_nested, body_fn_nested, [i1, j1, x1, y1]
                )
                return i1 - 1, j1.clone(), x1 * 2, y1 / 2

            return while_loop(cond_fn, body_fn, (ci, cj, a, b))

    class SimpleWithLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)
            self.dec = torch.nn.Buffer(torch.tensor(1))

        def forward(self, iter, x):
            def cond_fn(it, x):
                return it - self.dec > 0

            def body_fn(it, x):
                return it - 1, self.linear(x)

            return while_loop(cond_fn, body_fn, (iter, x))

    class NestedWithLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mod = SimpleWithLinear()
            self.outer_linear = torch.nn.Linear(2, 2)
            self.dec = torch.nn.Buffer(torch.tensor(1))

        def forward(self, iter, x):
            def cond_fn(it, x):
                return it - self.dec > 0

            def body_fn(it, x):
                return it - 1, self.outer_linear(self.mod(it, x)[1])

            return while_loop(cond_fn, body_fn, (iter, x))

    nested2 = Nested()
    simple_with_linear = SimpleWithLinear()
    nested_with_linear = NestedWithLinear()

    x = torch.zeros(1)
    y = torch.zeros(1)
    z = torch.zeros(1)
    return {
        "simple": (simple, (x,)),
        "nested": (nested, (x, y, z)),
        "nested2": (
            nested2,
            (torch.tensor(2), torch.tensor(2), torch.ones(2, 2), torch.ones(2, 2)),
        ),
        "simple_with_mutation": (simple_with_mutation, (x,)),
        "simple_with_linear": (
            simple_with_linear,
            (torch.tensor(3), torch.randn(2, 2)),
        ),
        "nested_with_linear": (
            nested_with_linear,
            (torch.tensor(3), torch.randn(2, 2)),
        ),
    }


WHILE_LOOP_TESTS = _while_loop_tests()


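# Walks every submodule graph of `gm` and collects the requested meta fields from
# nodes whose names appear in `node_names`.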
def collect_meta_for_filtered_nodes(
    gm: torch.fx.GraphModule, node_names, meta_field_name
):
    ret = []
    for mod in gm.modules():
        for node in mod.graph.nodes:
            if node.name in node_names:
                for field_name in meta_field_name:
                    ret.append(node.meta.get(field_name))
    return ret


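# The same sum-of-operands reduction expressed as a plain function, a callable
# object, and an nn.Module, so tests can pass different kinds of callables to the
# control-flow ops.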
def reduce_func(*operands):
    acc = 0
    for operand in operands:
        acc += operand
    return acc


class ReduceObj:
    def __call__(self, *operands):
        return reduce_func(*operands)


class ReduceMod(torch.nn.Module):
    def _reduce(self, *operands):
        return reduce_func(*operands)

    def forward(self, *operands):
        return self._reduce(*operands)


@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@skipIfNoDynamoSupport
class TestControlFlow(TestCase):
    def setUp(self):
        torch._dynamo.reset()
        super().setUp()

    def test_cond_no_trace(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        x = torch.randn(4)
        result = cond(False, true_fn, false_fn, [x])
        self.assertEqual(result, torch.cos(x))

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
    def test_cond_gpu(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        x = torch.randn(4, device="cuda")
        pred = torch.tensor(False, device="cuda")
        result = cond(pred, true_fn, false_fn, [x])
        self.assertEqual(result, torch.cos(x))

    def test_cond_autograd_simple(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            x = torch.randn(4, requires_grad=True)
            result = cond(pred, true_fn, false_fn, (x,))
            self.assertEqual(result, fn(x))

            grad_out = torch.ones_like(result)
            grads = torch.autograd.grad(result, (x,), grad_out)
            expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
            self.assertEqual(expected_grads, grads)

        def f(pred, x):
            result = cond(pred, true_fn, false_fn, (x,))
            grad_out = torch.ones_like(result)
            return torch.autograd.grad(result, (x,), grad_out)

        gm = make_fx(f, tracing_mode="symbolic")(pred, x)

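        # In the expected graph below, cond is captured as
        # torch.ops.higher_order.cond with true/false subgraphs, and a second
        # cond call computes the backward pass seeded with the ones_like grad
        # output.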
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1, x_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,));  true_graph_0 = false_graph_0 = None
    getitem = cond[0];  cond = None
    ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False);  getitem = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1));  pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
    getitem_1 = cond_1[0];  cond_1 = None
    return (getitem_1,)""",  # noqa: B950
        )

    def test_cond_autograd_complex(self):
        def true_fn(x):
            return torch.abs((x**2).sin())

        def false_fn(x):
            return (x + 42).cos()

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            x = torch.randn(4, requires_grad=True)
            result = cond(pred, true_fn, false_fn, (x,))
            self.assertEqual(result, fn(x))

            grad_out = torch.ones_like(result)
            grads = torch.autograd.grad(result, (x,), grad_out)
            expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
            self.assertEqual(expected_grads, grads)

        def f(pred, x):
            result = cond(pred, true_fn, false_fn, (x,))
            grad_out = torch.ones_like(result)
            return torch.autograd.grad(result, (x,), grad_out)

        gm = make_fx(f, tracing_mode="symbolic")(pred, x)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1, x_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,));  true_graph_0 = false_graph_0 = None
    getitem = cond[0];  cond = None
    ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False);  getitem = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1));  pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
    getitem_1 = cond_1[0];  cond_1 = None
    return (getitem_1,)""",  # noqa: B950
        )

    @skipIfTorchDynamo("Skip due to graph break when run with dynamo")
    def test_cond_autograd_nested(self):
        class Nested(torch.nn.Module):
            def forward(self, p0, p1, p2, a, b, c):
                def true_fn(x0, y0, z0):
                    def true_true_fn(x1, y1, z1):
                        return (x1 - y1 * z1) * 3.14

                    def true_false_fn(x1, y1, z1):
                        def true_false_true_fn(x2, y2, z2):
                            return (x2 * y2 * z2) / 2.71

                        def true_false_false_fn(x2, y2, z2):
                            return (x2 + y2 + z2) * 1.23

                        return torch.cond(
                            p2, true_false_true_fn, true_false_false_fn, [x1, y1, z1]
                        )

                    return torch.cond(p1, true_true_fn, true_false_fn, [x0, y0, z0])

                def false_fn(x0, y0, z0):
                    def false_true_fn(x1, y1, z1):
                        def false_true_true_fn(x2, y2, z2):
                            return (x2 - y2 - z2) + 1.23

                        def false_true_false_fn(x2, y2, z2):
                            return (x2 / y2 / z2) - 3.14

                        return torch.cond(
                            p2, false_true_true_fn, false_true_false_fn, [x1, y1, z1]
                        )

                    def false_false_fn(x1, y1, z1):
                        return (x1 - y1 * z1) / 2.71

                    return torch.cond(p1, false_true_fn, false_false_fn, [x0, y0, z0])

                return torch.cond(p0, true_fn, false_fn, [a, b, c])

        nn_module = Nested()

        def true_fn(x):
            return nn_module(
                torch.tensor(False), torch.tensor(True), torch.tensor(False), x, x, x
            )

        def false_fn(x):
            return nn_module(
                torch.tensor(True), torch.tensor(False), torch.tensor(True), x, x, x
            )

        x = torch.randn(4, requires_grad=True)

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            result = cond(pred, true_fn, false_fn, (x,))
            self.assertEqual(result, fn(x))

            grad_out = torch.ones_like(result)
            grads = torch.autograd.grad(result, (x,), grad_out)
            expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
            self.assertEqual(expected_grads, grads)

    @skipIfTorchDynamo("Skip due to graph break when run with dynamo")
    def test_cond_autograd_mixed_require_grad(self):
        def true_fn(x, y, z):
            return x * y * z

        def false_fn(x, y, z):
            return x + y + z

        x = torch.randn(4, requires_grad=True)
        y = torch.randn(4, requires_grad=False)

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            result = cond(pred, true_fn, false_fn, (x, y, x))
            self.assertEqual(result, fn(x, y, x))

            grad_out = torch.ones_like(result)
            grads = torch.autograd.grad(result, (x,), grad_out)
            expected_grads = torch.autograd.grad(fn(x, y, x), (x,), grad_out)
            self.assertEqual(expected_grads, grads)

        def f(pred, x, y, z):
            result = cond(pred, true_fn, false_fn, (x, y, z))
            grad_out = torch.ones_like(result)
            return torch.autograd.grad(result, (x,), grad_out)

        gm = make_fx(f, tracing_mode="symbolic")(pred, x, y, x)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1, x_1, y_1, z_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (z_1, y_1));  true_graph_0 = false_graph_0 = None
    getitem = cond[0];  cond = None
    ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False);  getitem = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1));  pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = None
    getitem_1 = cond_1[0]
    getitem_2 = cond_1[1];  cond_1 = getitem_2 = None
    return (getitem_1,)""",  # noqa: B950
        )

    @skipIfTorchDynamo("Skip due to graph break when run with dynamo")
    def test_cond_autograd_grad_through_cond(self):
        nn_module = torch.nn.Linear(4, 4)

        def true_fn(x):
            return nn_module(x)

        def false_fn(X):
            # Note: uses the outer `x` captured by closure rather than the parameter `X`.
            return x * nn_module(x)

        x = torch.randn(4, requires_grad=True)

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            result = cond(pred, true_fn, false_fn, (x,))
            self.assertEqual(result, fn(x))

            grad_out = torch.ones_like(result)
            grads = torch.autograd.grad(result, (nn_module.weight,), grad_out)
            expected_grads = torch.autograd.grad(
                fn(
                    x,
                ),
                (nn_module.weight,),
                grad_out,
            )
            self.assertEqual(expected_grads, grads)

        def f(pred, x):
            result = cond(pred, true_fn, false_fn, (x,))
            grad_out = torch.ones_like(result)
            return torch.autograd.grad(result, (nn_module.weight,), grad_out)

        # need to set _allow_non_fake_inputs = True because model parameters don't
        # get fakified.
        gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred, x)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1, x_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    _param_constant0 = self._param_constant0
    _param_constant1 = self._param_constant1
    _tensor_constant0 = self._tensor_constant0
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_param_constant0, _param_constant1, x_1, _tensor_constant0));  true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _tensor_constant0 = None
    getitem = cond[0];  cond = None
    ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False);  getitem = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    _param_constant0_1 = self._param_constant0
    _param_constant1_1 = self._param_constant1
    _tensor_constant0_1 = self._tensor_constant0
    cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, _tensor_constant0_1));  pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = _tensor_constant0_1 = None
    getitem_1 = cond_1[0];  getitem_1 = None
    getitem_2 = cond_1[1]
    getitem_3 = cond_1[2];  getitem_3 = None
    getitem_4 = cond_1[3];  cond_1 = getitem_4 = None
    return (getitem_2,)""",  # noqa: B950
        )

    def test_cond_in_forloop(self):
        def for_loop_fake(x):
            for i in range(3):
                x = x * x + 1
            return x

        def for_loop_test(x):
            for i in range(3):
                pred = i < 3

                def true_fn(x):
                    return x * x + 1

                def false_fn(x):
                    return x

                x = cond(pred, true_fn, false_fn, (x,))

            return x

        x = torch.ones(4, requires_grad=True)
        x_new = for_loop_test(x)
        x_exp = for_loop_fake(x)

        self.assertEqual(x_new, x_exp)

        grad_out = torch.ones_like(x_new)
        grads = torch.autograd.grad(x_new, (x,), grad_out)
        expected_grads = torch.autograd.grad(x_exp, (x,), grad_out)
        self.assertEqual(expected_grads, grads)

        def f(x):
            x_new = for_loop_test(x)
            grad_out = torch.ones_like(x_new)
            return torch.autograd.grad(x_new, (x,), grad_out)

        gm = make_fx(f, tracing_mode="symbolic")(x)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1):
    mul = torch.ops.aten.mul.Tensor(x_1, x_1)
    add = torch.ops.aten.add.Tensor(mul, 1);  mul = None
    mul_1 = torch.ops.aten.mul.Tensor(add, add)
    add_1 = torch.ops.aten.add.Tensor(mul_1, 1);  mul_1 = None
    mul_2 = torch.ops.aten.mul.Tensor(add_1, add_1)
    add_2 = torch.ops.aten.add.Tensor(mul_2, 1);  mul_2 = None
    ones_like = torch.ops.aten.ones_like.default(add_2, pin_memory = False);  add_2 = None
    mul_3 = torch.ops.aten.mul.Tensor(ones_like, add_1)
    mul_4 = torch.ops.aten.mul.Tensor(ones_like, add_1);  ones_like = add_1 = None
    add_3 = torch.ops.aten.add.Tensor(mul_4, mul_3);  mul_4 = mul_3 = None
    mul_5 = torch.ops.aten.mul.Tensor(add_3, add)
    mul_6 = torch.ops.aten.mul.Tensor(add_3, add);  add_3 = add = None
    add_4 = torch.ops.aten.add.Tensor(mul_6, mul_5);  mul_6 = mul_5 = None
    mul_7 = torch.ops.aten.mul.Tensor(add_4, x_1)
    mul_8 = torch.ops.aten.mul.Tensor(add_4, x_1);  add_4 = x_1 = None
    add_5 = torch.ops.aten.add.Tensor(mul_8, mul_7);  mul_8 = mul_7 = None
    return (add_5,)""",  # noqa: B950
        )

    @skipIfTorchDynamo("Skip due to graph break when run with dynamo")
    def test_cond_autograd_pytree_not_all_inputs_used(self):
        def true_fn(x):
            return x["t"][0] + x["t"][1]["b"]

        def false_fn(x):
            return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])

        a = torch.randn(4, requires_grad=True)
        b = torch.randn(4, requires_grad=True)
        c = torch.randn(4, requires_grad=True)

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
            self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))

            grad_out = torch.ones_like(result)
            if pred:
                with self.assertRaisesRegex(Exception, r"."):
                    grads = torch.autograd.grad(result, (a, b, c), grad_out)
                    expected_grads = torch.autograd.grad(
                        fn({"t": [a, {"b": b}, (c,)]}), (a, b, c), grad_out
                    )
                    self.assertEqual(expected_grads, grads)

        def f(pred, a, b, c):
            result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
            grad_out = torch.ones_like(result)
            return torch.autograd.grad(result, (a, b), grad_out)

        gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(
            pred, a, b, c
        )
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1, a_1, b_1, c_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (a_1, b_1, c_1));  true_graph_0 = false_graph_0 = None
    getitem = cond[0];  cond = None
    ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False);  getitem = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, c_1));  pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = c_1 = None
    getitem_1 = cond_1[0]
    getitem_2 = cond_1[1]
    getitem_3 = cond_1[2];  cond_1 = getitem_3 = None
    return (getitem_1, getitem_2)""",  # noqa: B950
        )
        # Forward
        self.assertExpectedInline(
            gm.true_graph_0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    add = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    return (add,)""",
        )
        # Backward
        self.assertExpectedInline(
            gm.true_graph_1.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
    add = torch.ops.aten.add.Tensor(arg1_1, arg2_1);  arg1_1 = arg2_1 = add = None
    clone = torch.ops.aten.clone.default(arg0_1)
    clone_1 = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
    return [clone, clone_1, None]""",
        )

    def test_cond_autograd_pytree_input(self):
        def true_fn(x):
            return x["t"][0] + x["t"][1]["b"] * x["t"][2][0]

        def false_fn(x):
            return x["t"][0] * (x["t"][2][0] / x["t"][1]["b"])

        a = torch.randn(4, requires_grad=True)
        b = torch.randn(4, requires_grad=True)
        c = torch.randn(4, requires_grad=True)

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
            self.assertEqual(result, fn({"t": [a, {"b": b}, (c,)]}))

            grad_out = torch.ones_like(result)
            grads = torch.autograd.grad(result, (a, b), grad_out)
            expected_grads = torch.autograd.grad(
                fn({"t": [a, {"b": b}, (c,)]}), (a, b), grad_out
            )
            self.assertEqual(expected_grads, grads)

        def f(pred):
            result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
            grad_out = torch.ones_like(result)
            return torch.autograd.grad(result, (a, b), grad_out)

        # need to set _allow_non_fake_inputs = True because model parameters don't
        # get fakified.
        gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    _tensor_constant0 = self._tensor_constant0
    _tensor_constant1 = self._tensor_constant1
    _tensor_constant2 = self._tensor_constant2
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2));  true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
    getitem = cond[0];  cond = None
    ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False);  getitem = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    _tensor_constant0_1 = self._tensor_constant0
    _tensor_constant1_1 = self._tensor_constant1
    _tensor_constant2_1 = self._tensor_constant2
    cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1));  pred_1 = true_graph_1 = false_graph_1 = ones_like = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = None
    getitem_1 = cond_1[0]
    getitem_2 = cond_1[1]
    getitem_3 = cond_1[2];  cond_1 = getitem_3 = None
    return (getitem_1, getitem_2)""",  # noqa: B950
        )

    def test_cond_autograd_different_pytree_output(self):
        def true_fn(x):
            return x["t"][0], {"r": x["t"][2][0] / x["t"][1]["b"]}, [x["t"][2][0]]

        def false_fn(x):
            return {"res": [x["t"][0] * x["t"][1]["b"], x["t"][2][0]]}

        a = torch.randn(4, requires_grad=True)
        b = torch.randn(4, requires_grad=True)
        c = torch.randn(4, requires_grad=True)

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            with self.assertRaisesRegex(
                torch._dynamo.exc.UncapturedHigherOrderOpError,
                "Cond doesn't work unless it is captured completely with torch.compile",
            ):
                cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))

    @skipIfTorchDynamo("Skip due to graph break when run with dynamo")
    def test_cond_autograd_same_pytree_output(self):
        def true_fn(x):
            return {"res": [x["t"][0], (x["t"][2][0],)]}

        def false_fn(x):
            return {"res": [x["t"][1]["b"], (x["t"][2][0],)]}

        a = torch.randn(4, requires_grad=True)
        b = torch.randn(4, requires_grad=True)
        c = torch.randn(4, requires_grad=True)

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
            result_exp = fn({"t": [a, {"b": b}, (c,)]})
            self.assertEqual(result, result_exp)

            result_flat, _ = pytree.tree_flatten(result)
            result_exp_flat, _ = pytree.tree_flatten(result_exp)

            grad_out = [torch.ones_like(g) for g in result_flat]
            expected_grads = torch.autograd.grad(result_exp_flat, (c,), grad_out)
            grads = torch.autograd.grad(result_flat, (c,), grad_out)
            self.assertEqual(expected_grads, grads)

        def f(pred):
            result = cond(pred, true_fn, false_fn, ({"t": [a, {"b": b}, (c,)]},))
            return result

        gm = make_fx(f, tracing_mode="symbolic", _allow_non_fake_inputs=True)(pred)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    _tensor_constant0 = self._tensor_constant0
    _tensor_constant1 = self._tensor_constant1
    _tensor_constant2 = self._tensor_constant2
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2));  pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None
    getitem = cond[0]
    getitem_1 = cond[1];  cond = None
    view = torch.ops.aten.view.default(getitem, [4]);  getitem = None
    view_1 = torch.ops.aten.view.default(getitem_1, [4]);  getitem_1 = None
    return {'res': [view, (view_1,)]}""",  # noqa: B950
        )

    @skipIfTorchDynamo("Skip due to graph break when run with dynamo")
    def test_cond_autograd_torch_nn_module(self):
        nn_module_true = torch.nn.Linear(4, 4)

        def true_fn(x):
            return nn_module_true(torch.abs((x**2).sin()))

        nn_module_false = torch.nn.GRUCell(4, 4)

        def false_fn(x):
            return nn_module_false((x + 42).cos())

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            x = torch.randn(4, requires_grad=True)
            result = cond(pred, true_fn, false_fn, (x,))
            self.assertEqual(result, fn(x))

            grad_out = torch.ones_like(result)
            grads = torch.autograd.grad(result, (x,), grad_out)
            expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
            self.assertEqual(expected_grads, grads)

        def f(pred, x):
            result = cond(pred, true_fn, false_fn, (x,))
            grad_out = torch.ones_like(result)
            return torch.autograd.grad(result, (x,), grad_out)

        gm = make_fx(f)(pred, x)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1, x_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    _param_constant0 = self._param_constant0
    _param_constant1 = self._param_constant1
    _param_constant2 = self._param_constant2
    _param_constant3 = self._param_constant3
    _param_constant4 = self._param_constant4
    _param_constant5 = self._param_constant5
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1, _param_constant0, _param_constant1, _param_constant2, _param_constant3, _param_constant4, _param_constant5));  true_graph_0 = false_graph_0 = _param_constant0 = _param_constant1 = _param_constant2 = _param_constant3 = _param_constant4 = _param_constant5 = None
    getitem = cond[0];  cond = None
    ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False);  getitem = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    _param_constant0_1 = self._param_constant0
    _param_constant1_1 = self._param_constant1
    _param_constant2_1 = self._param_constant2
    _param_constant3_1 = self._param_constant3
    _param_constant4_1 = self._param_constant4
    _param_constant5_1 = self._param_constant5
    cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1));  pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = None
    getitem_1 = cond_1[0]
    getitem_2 = cond_1[1];  getitem_2 = None
    getitem_3 = cond_1[2];  getitem_3 = None
    getitem_4 = cond_1[3];  getitem_4 = None
    getitem_5 = cond_1[4];  getitem_5 = None
    getitem_6 = cond_1[5];  getitem_6 = None
    getitem_7 = cond_1[6];  cond_1 = getitem_7 = None
    return (getitem_1,)""",  # noqa: B950
        )

    def test_cond_autograd_user_nn_module(self):
        class User_nn_module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, input):
                return input * input

        nn_module_true = User_nn_module()

        def true_fn(x):
            return nn_module_true(torch.abs((x**2).sin()))

        nn_module_false = torch.nn.ReLU(inplace=False)

        def false_fn(x):
            return nn_module_false((x + 42).cos())

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            x = torch.randn(4, requires_grad=True)
            result = cond(pred, true_fn, false_fn, (x,))
            self.assertEqual(result, fn(x))

            grad_out = torch.ones_like(result)
            grads = torch.autograd.grad(result, (x,), grad_out)
            expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
            self.assertEqual(expected_grads, grads)

        def f(pred, x):
            result = cond(pred, true_fn, false_fn, (x,))
            grad_out = torch.ones_like(result)
            return torch.autograd.grad(result, (x,), grad_out)

        gm = make_fx(f)(pred, x)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1, x_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,));  true_graph_0 = false_graph_0 = None
    getitem = cond[0];  cond = None
    ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False);  getitem = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1));  pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
    getitem_1 = cond_1[0];  cond_1 = None
    return (getitem_1,)""",  # noqa: B950
        )

    def test_cond_autograd_inner_fn(self):
        def true_fn(x):
            return torch.abs((x**2).sin())

        def false_fn(x):
            def inner_fn(x):
                return x**2

            return torch.abs(inner_fn(x).sin())

        x = torch.randn(4, requires_grad=True)
        pred = torch.tensor(False)
        fn = false_fn
        result_false = cond(pred, true_fn, false_fn, (x,))
        self.assertEqual(result_false, fn(x))

        grad_out = torch.ones_like(result_false)
        grads_false = torch.autograd.grad(result_false, (x,), grad_out)
        expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
        self.assertEqual(expected_grads, grads_false)

        pred = torch.tensor(True)
        fn = true_fn
        result_true = cond(pred, true_fn, false_fn, (x,))
        self.assertEqual(result_true, fn(x))
        self.assertEqual(result_false, result_true)

        grad_out = torch.ones_like(result_true)
        grads_true = torch.autograd.grad(result_true, (x,), grad_out)
        expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
        self.assertEqual(expected_grads, grads_true)
        self.assertEqual(grads_false, grads_true)

        def f(pred, x):
            result = cond(pred, true_fn, false_fn, (x,))
            grad_out = torch.ones_like(result)
            return torch.autograd.grad(result, (x,), grad_out)

        gm = make_fx(f)(pred, x)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1, x_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,));  true_graph_0 = false_graph_0 = None
    getitem = cond[0];  cond = None
    ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False);  getitem = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1));  pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
    getitem_1 = cond_1[0];  cond_1 = None
    return (getitem_1,)""",  # noqa: B950
        )

    def test_cond_autograd_inner_tensor(self):
        def true_fn(x):
            return torch.abs((x**2).sin())

        def false_fn(x):
            y = torch.ones(4, requires_grad=False) * 42
            return (x * y).cos()

        for pred, fn in zip(
            [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn]
        ):
            x = torch.randn(4, requires_grad=True)
            result = cond(pred, true_fn, false_fn, (x,))
            self.assertEqual(result, fn(x))

            grad_out = torch.ones_like(result)
            grads = torch.autograd.grad(result, (x,), grad_out)
            expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
            self.assertEqual(expected_grads, grads)

        def f(pred, x):
            result = cond(pred, true_fn, false_fn, (x,))
            grad_out = torch.ones_like(result)
            return torch.autograd.grad(result, (x,), grad_out)

        gm = make_fx(f, tracing_mode="symbolic")(pred, x)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, pred_1, x_1):
    true_graph_0 = self.true_graph_0
    false_graph_0 = self.false_graph_0
    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,));  true_graph_0 = false_graph_0 = None
    getitem = cond[0];  cond = None
    ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False);  getitem = None
    true_graph_1 = self.true_graph_1
    false_graph_1 = self.false_graph_1
    cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1));  pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
    getitem_1 = cond_1[0];  cond_1 = None
    return (getitem_1,)""",  # noqa: B950
        )

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
    def test_cond_autograd_gpu(self):
        def true_fn(x):
            return x.sin()

        def false_fn(x):
            return x.cos()

        for pred, fn in zip(
            [torch.tensor(False, device="cuda"), torch.tensor(True, device="cuda")],
            [false_fn, true_fn],
        ):
            x = torch.randn(4, requires_grad=True, device="cuda")
            result = cond(pred, true_fn, false_fn, (x,))
            self.assertEqual(result, fn(x))

            grad_out = torch.ones_like(result)
            grads = torch.autograd.grad(result, (x,), grad_out)
            expected_grads = torch.autograd.grad(fn(x), (x,), grad_out)
            self.assertEqual(expected_grads, grads)

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
    def test_map_gpu(self):
        def f(x, y):
            return x + y

        xs = torch.ones(3, 2, 2, device="cuda")
        y = torch.ones(2, device="cuda")
        res = control_flow.map(f, xs, y)
        expected = _fake_map(f, xs, y)
        self.assertEqual(expected, res)

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
    def test_while_loop_gpu(self):
        def cond_fn(x):
            return x.sum() < 10

        def body_fn(x):
            return (x + 1,)

        x = torch.zeros(1, device="cuda")
        res = while_loop(cond_fn, body_fn, (x,))
        expected = _fake_while_loop(cond_fn, body_fn, (x,))
        self.assertEqual(expected, res)

    def test_map_illegal_inputs(self):
        def f(x, y):
            return x[0] + x[1] + y

        with self.assertRaisesRegex(
            RuntimeError,
            r"Mapped xs can only consist of tensors\. Got xs \[3, tensor\(\[1\., 1\.\]\)\]\.",
        ):
            _ = control_flow.map(f, (3, torch.ones(2)), torch.ones(2))

        with self.assertRaisesRegex(
            RuntimeError, r"Leading dimensions of mapped xs cannot be 0\."
        ):
            _ = control_flow.map(
                f, (torch.ones(0, 1, 2), torch.ones(0, 1, 2)), torch.ones(2)
            )

        with self.assertRaisesRegex(
            RuntimeError,
            r"Leading dimensions of mapped xs must be consistent\. "
            r"Got shapes \[torch\.Size\(\[3, 4, 5\]\), torch\.Size\(\[4, 4, 5\]\)\]\.",
        ):
            _ = control_flow.map(
                f, (torch.ones(3, 4, 5), torch.ones(4, 4, 5)), torch.ones(5)
            )

    def test_map_illegal_outputs(self):
        def f(x, y):
            return x.item()

        def f1(x, y):
            return y.size()

        def f2(x, y):
            return None

        x = torch.ones([3])
        y = torch.ones([1, 2, 3])
        with self.assertRaisesRegex(
            RuntimeError, r"Expect outputs of map only contains tensors or None\."
        ):
            _ = control_flow.map(f, x, y)

        with self.assertRaisesRegex(
            RuntimeError, r"Expect outputs of map only contains tensors or None\."
        ):
            out = control_flow.map(f1, x, y)

        # return None is OK
        _ = control_flow.map(f2, x, y)

    def test_map_list_in_out(self):
        def f(x, y):
            return [[x[0][0] + y]]

        xs = [[torch.ones(3, 2, 2)]]
        y = torch.ones(2)
        res = control_flow.map(f, xs, y)
        expected = _fake_map(f, xs, y)
        self.assertEqual(len(res), 1)
        self.assertEqual(len(res[0]), 1)
        self.assertEqual(expected, res)

    def test_map_dict_in_out(self):
        def f(x, y):
            return {"c": x["a"]["b"] + y}

        xs = {"a": {"b": torch.ones(3, 2, 2)}}
        y = torch.ones(2)
        res = control_flow.map(f, xs, y)
        expected = _fake_map(f, xs, y)
        self.assertEqual(len(res), 1)
        self.assertTrue("c" in res)
        self.assertEqual(expected, res)

    def test_map_autograd_simple(self):
        def f(x, y):
            return x.sin().cos() * y.cos().sin()

        xs = torch.ones(3, 2, 2, requires_grad=True)
        y = torch.ones(2, requires_grad=True)
        res = control_flow.map(f, xs, y)
        expected_res = _fake_map(f, xs, y)
        grad_out = torch.ones_like(res)
        grads = torch.autograd.grad(res, (xs, y), grad_out)
        expected_grads = torch.autograd.grad(expected_res, (xs, y), grad_out)
        self.assertEqual(expected_res, res)
        self.assertEqual(expected_grads, grads)

    def test_map_autograd_simple_partial_grad(self):
        def f(x, y):
            return x.sin().cos() * y.cos().sin()

        xs = torch.ones(3, 2, 2, requires_grad=True)
        # Disable the gradient computation for y
        y = torch.ones(2, requires_grad=False)
        res = control_flow.map(f, xs, y)
        expected_res = _fake_map(f, xs, y)
        grad_out = torch.ones_like(res)
        grads = torch.autograd.grad(res, (xs,), grad_out)
        expected_grads = torch.autograd.grad(expected_res, (xs,), grad_out)
        self.assertEqual(expected_res, res)
        self.assertEqual(expected_grads, grads)

    def test_map_autograd_no_grad_output(self):
        def f(x, y):
            return x[0].sin().cos() + y, y.cos().sin()

        xs = [torch.ones(3, 2, 2, requires_grad=True), torch.ones(3, 3)]
        # Disable the gradient computation for y
        y = torch.ones(2, requires_grad=False)
        res = control_flow.map(f, xs, y)
        expected_res = _fake_map(f, xs, y)
        grad_out = torch.ones_like(res[0])
        grads = torch.autograd.grad(res[0], (xs[0],), grad_out)
        expected_grads = torch.autograd.grad(expected_res[0], (xs[0],), grad_out)
        self.assertEqual(expected_res, res)
        self.assertEqual(expected_grads, grads)

    def test_map_autograd_nested_list(self):
        import torch.utils._pytree as pytree

        def f(x, y):
            a, b = x
            c, d = a
            return [[b.sin() * c.cos()], d.sin() * y.cos()]

        def fwbw(map_op, f, x, y):
            z = map_op(f, x, y)
            flat_x = pytree.tree_leaves(x)
            flat_z = pytree.tree_leaves(z)
            grads = torch.autograd.grad(
                flat_z, flat_x, [torch.ones_like(z) for z in flat_z]
            )
            return z, grads

        x = [
            [
                torch.randn(3, 2, 2, requires_grad=True),
                torch.randn(3, 2, 1, requires_grad=True),
            ],
            torch.ones(3, 1, 2, requires_grad=True),
        ]
        y = torch.ones(1, requires_grad=True)
        true_outs = fwbw(control_flow.map, f, x, y)
        fake_outs = fwbw(_fake_map, f, x, y)
        self.assertEqual(true_outs, fake_outs)

    @unittest.skipIf(not SM70OrLater, "triton")
    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
    @parametrize("reverse", [False, True])
    @parametrize("combine_mode", ["pointwise", "generic"])
    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
    # Skipping the combination of combine_mode=pointwise and device=cpu,
    # as the current pointwise implementation only supports CUDA devices.
    @decorateIf(
        unittest.skip,
        lambda params: (
            params["combine_mode"] == "pointwise"
            and (params["device"] == torch.device("cpu") or torch.version.hip)
        ),
    )
    def test_pointwise_associative_scan_simple(self, reverse, combine_mode, device):
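        # associative_scan applies the binary combine_fn cumulatively along dim;
        # with add/mul it corresponds to cumsum/cumprod, which the eager
        # _fake_associative_scan reference reproduces.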
        def add(x: torch.Tensor, y: torch.Tensor):
            return x + y

        def mul(x: torch.Tensor, y: torch.Tensor):
            return x * y

        x = torch.randn(3, 10, 2, device=device)

        for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]:
            result = associative_scan(
                op, x, 0, reverse=reverse, combine_mode=combine_mode
            )
            result_exp = _fake_associative_scan(op, x, 0, reverse=reverse)
            self.assertEqual(result, result_exp)

        # Jax Examples
        x = torch.arange(0, 4, device=device)
        cumsum1 = associative_scan(
            add, x, 0, reverse=reverse, combine_mode=combine_mode
        )
        cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse)
        if not reverse:
            self.assertEqual(
                cumsum1, torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64)
            )
        else:
            self.assertEqual(
                cumsum1, torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64)
            )
        self.assertEqual(cumsum1, cumsum_exp)

    @unittest.skipIf(not SM70OrLater, "triton")
    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
    @parametrize("reverse", [False, True])
    @parametrize("combine_mode", ["pointwise", "generic"])
    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
    # Skipping the combination of combine_mode=pointwise and device=cpu,
    # as the current pointwise implementation only supports CUDA devices.
    @decorateIf(
        unittest.skip,
        lambda params: (
            params["combine_mode"] == "pointwise"
            and (params["device"] == torch.device("cpu") or torch.version.hip)
        ),
    )
    def test_pointwise_associative_scan_dim(self, reverse, combine_mode, device):
        import random

        def add(x: torch.Tensor, y: torch.Tensor):
            return x + y

        def mul(x: torch.Tensor, y: torch.Tensor):
            return x * y

        num_dims = [random.randint(2, 5) for _ in range(10)]
        for num_dim in num_dims:
            shapes = [random.randint(1, 10) for _ in range(num_dim)]
            rnd_scan_dim = random.randint(0, num_dim - 1)
            x = torch.randn(*shapes, device=device)

            for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]:
                result = associative_scan(
                    op, x, rnd_scan_dim, reverse=reverse, combine_mode=combine_mode
                )
                result_exp = _fake_associative_scan(
                    op, x, rnd_scan_dim, reverse=reverse
                )
                self.assertEqual(result, result_exp)
                if not reverse:
                    result_exp_PT = op_pt(x, rnd_scan_dim)
                    self.assertEqual(result, result_exp_PT)

    @unittest.skipIf(not SM70OrLater, "triton")
    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
    @parametrize("reverse", [False, True])
    @parametrize("combine_mode", ["pointwise", "generic"])
    @parametrize("compile_mode", ["compile", "compile_dynamic_shape"])
    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
    # Skipping the combination of combine_mode=pointwise and device=cpu,
    # as the current pointwise implementation only supports CUDA devices.
    @decorateIf(
        unittest.skip,
        lambda params: (
            params["combine_mode"] == "pointwise"
            and (params["device"] == torch.device("cpu") or torch.version.hip)
        ),
    )
    def test_pointwise_associative_scan_compile(
        self, reverse, combine_mode, compile_mode, device
    ):
        def add(x: torch.Tensor, y: torch.Tensor):
            return x + y

        def mul(x: torch.Tensor, y: torch.Tensor):
            return x * y

        x = torch.randn(3, 10, 2, device=device)
        torch.compiler.reset()
        if compile_mode == "compile":
            associative_scan_fct = torch.compile(
                associative_scan, fullgraph=True, dynamic=False
            )
        else:
            associative_scan_fct = torch.compile(
                associative_scan, fullgraph=True, dynamic=True
            )

        for op, op_pt in [(add, torch.cumsum), (mul, torch.cumprod)]:
            result = associative_scan_fct(
                op, x, 0, reverse=reverse, combine_mode=combine_mode
            )
            result_exp = _fake_associative_scan(op, x, 0, reverse=reverse)
            self.assertEqual(result, result_exp)
            if not reverse:
                result_exp_PT = op_pt(x, 0)
                self.assertEqual(result, result_exp_PT)

        # Jax Examples
        x = torch.arange(0, 4, device=device)
        cumsum1 = associative_scan(
            add, x, 0, reverse=reverse, combine_mode=combine_mode
        )
        cumsum_exp = _fake_associative_scan(add, x, 0, reverse=reverse)
        if not reverse:
            self.assertEqual(
                cumsum1, torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64)
            )
        else:
            self.assertEqual(
                cumsum1, torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64)
            )
        self.assertEqual(cumsum1, cumsum_exp)

    @skipIfRocm(msg="Unsupported on ROCM yet")
    @unittest.skipIf(not SM70OrLater, "triton")
    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
    @parametrize("reverse", [False, True])
    @parametrize("combine_mode", ["pointwise", "generic"])
    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
    # Skipping the combination of combine_mode=pointwise and device=cpu,
    # as the current pointwise implementation only supports CUDA devices.
1365    @decorateIf(
1366        unittest.skip,
1367        lambda params: (
1368            params["combine_mode"] == "pointwise"
1369            and (params["device"] == torch.device("cpu") or torch.version.hip)
1370        ),
1371    )
1372    def test_pointwise_associative_scan_binary_operator(
1373        self, reverse, combine_mode, device
1374    ):
1375        def fct(x, y):
1376            A_i, Bu_i = x
1377            A_j, Bu_j = y
1378            return A_j * A_i, A_j * Bu_i + Bu_j
1379
1380        torch.compiler.reset()
1381        associative_scan1 = torch.compile(associative_scan, fullgraph=True)
1382        associative_scan2 = associative_scan
1383
1384        state_dim = 20
1385        timesteps = 10
1386        projected_inputs = torch.randn(
1387            timesteps, state_dim, requires_grad=True, device=device
1388        )
1389        A = torch.randn(state_dim, requires_grad=True, device=device)
1390        elements = (A.repeat((timesteps, 1)), projected_inputs)
1391
1392        result1 = associative_scan1(
1393            fct, elements, 0, combine_mode=combine_mode, reverse=reverse
1394        )
1395        result2 = associative_scan2(
1396            fct, elements, 0, combine_mode=combine_mode, reverse=reverse
1397        )
1398        expected_result = _fake_associative_scan(fct, elements, 0, reverse=reverse)
1399        self.assertEqual(
1400            result1,
1401            expected_result,
1402        )
1403        self.assertEqual([r.device.type for r in result1], [device.type] * len(result1))
1404        self.assertEqual(
1405            result2,
1406            expected_result,
1407        )
1408        self.assertEqual([r.device.type for r in result2], [device.type] * len(result2))
1409
1410    @skipIfRocm(msg="Not yet supported on ROCm")
1411    @unittest.skipIf(not SM70OrLater, "triton")
1412    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
1413    @parametrize("reverse", [False, True])
1414    @parametrize("combine_mode", ["pointwise", "generic"])
1415    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
1416    # Skipping the combination of combine_mode=pointwise and device=cpu
1417    # as the current pointwise implementation only supports CUDA devices
1418    @decorateIf(
1419        unittest.skip,
1420        lambda params: (
1421            params["combine_mode"] == "pointwise"
1422            and (params["device"] == torch.device("cpu") or torch.version.hip)
1423        ),
1424    )
1425    def test_pointwise_associative_scan_tuple(self, reverse, combine_mode, device):
1426        def fct(x, y):
1427            return (x[0] + y[0], x[1] * y[1])
1428
1429        x = torch.randn(3, 2, 2, device=device, requires_grad=True)
1430        y = torch.randn(3, 2, 2, device=device, requires_grad=True)
1431        inp = (x, y)
1432
1433        result1 = associative_scan(
1434            fct, inp, 0, reverse=reverse, combine_mode=combine_mode
1435        )
1436        expected_result = _fake_associative_scan(fct, inp, 0, reverse=reverse)
1437        self.assertEqual(result1, expected_result)
1438
1439    @skipIfRocm(msg="Not yet supported on ROCm")
1440    @unittest.skipIf(not SM70OrLater, "triton")
1441    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
1442    @parametrize("reverse", [False, True])
1443    @parametrize("combine_mode", ["pointwise", "generic"])
1444    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
1445    # Skipping the combination of combine_mode=pointwise and device=cpu
1446    # as the current pointwise implementation only supports CUDA devices
1447    @decorateIf(
1448        unittest.skip,
1449        lambda params: (
1450            params["combine_mode"] == "pointwise"
1451            and (params["device"] == torch.device("cpu") or torch.version.hip)
1452        ),
1453    )
1454    def test_pointwise_associative_scan_complex_pytree(
1455        self, reverse, combine_mode, device
1456    ):
1457        def fct_wrong_pytree(x, y):
1458            return {
1459                "i": x["i"] * y["j"][0][0],
1460                "k": 0.0,
1461                "j": ([x["j"][1][0]["o"]], [{"o": torch.sin(x["i"])}]),
1462            }
1463
1464        def fct_pointwise(x, y):
1465            return {
1466                "i": x["i"] * y["i"],
1467                "j": (
1468                    [x["j"][0][0] * y["j"][0][0]],
1469                    [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}],
1470                ),
1471            }
1472
1473        x = torch.randn(3, 2, 2, device=device, requires_grad=True)
1474        y = torch.randn(3, 2, 2, device=device, requires_grad=True)
1475        z = torch.randn(3, 2, 2, device=device, requires_grad=True)
1476        inp = {"i": x, "j": ([y], [{"o": z}])}
1477
1478        with self.assertRaisesRegex(Exception, r"."):
1479            result = associative_scan(fct_wrong_pytree, inp, 0, combine_mode="generic")
1480
1481        torch.compiler.reset()
1482        associative_scan1 = torch.compile(associative_scan, fullgraph=True)
1483        associative_scan2 = associative_scan
1484
1485        result1 = associative_scan1(
1486            fct_pointwise, inp, 0, combine_mode=combine_mode, reverse=reverse
1487        )
1488        result2 = associative_scan2(
1489            fct_pointwise, inp, 0, combine_mode=combine_mode, reverse=reverse
1490        )
1491        expected_result = _fake_associative_scan(fct_pointwise, inp, 0, reverse=reverse)
1492        self.assertEqual(result1, expected_result)
1493        self.assertEqual(result2, expected_result)
1494
1495    @unittest.skipIf(not SM70OrLater, "triton")
1496    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
1497    @parametrize("reverse", [False, True])
1498    @parametrize("device", [torch.device("cpu"), torch.device("cuda")])
1499    def test_generic_associative_scan_generic_simple(self, reverse, device):
1500        def non_pointwise(x: torch.Tensor, y: torch.Tensor):
1501            W = torch.diag(torch.ones(2, device=device))
1502            return x @ W + y @ W
1503
1504        x = torch.randn(3, 10, 2, device=device)
1505        with self.assertRaisesRegex(Exception, ".*"):
1506            out = associative_scan(
1507                non_pointwise, x, 0, reverse=reverse, combine_mode="pointwise"
1508            )
1509
1510        result1 = associative_scan(
1511            non_pointwise, x, 0, reverse=reverse, combine_mode="generic"
1512        )
1513        result_expected = _fake_associative_scan(non_pointwise, x, 0, reverse=reverse)
1514        self.assertEqual(result1, result_expected)
1515
1516
1517@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
1518@skipIfNoDynamoSupport
1519class TestControlFlowTraced(TestCase):
1520    def setUp(self):
1521        torch._dynamo.reset()
1522        super().setUp()
1523
1524    def _check_tracing(self, fn, args, allow_non_fake_inputs=False):
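        # Trace fn with make_fx under the symbolic, real, and fake tracing
        # modes and check that each traced graph reproduces the eager result.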
1525        graphs = {}
1526        eager_res = fn(*args)
1527        for tracing_mode in ["symbolic", "real", "fake"]:
1528            graph = make_fx(
1529                fn,
1530                tracing_mode=tracing_mode,
1531                _allow_non_fake_inputs=allow_non_fake_inputs,
1532            )(*args)
1533            graphs[tracing_mode] = graph
1534            self.assertEqual(graph(*args), eager_res)
1535        return graphs
1536
1537    def _check_compile(self, fn, args, *, backend="eager"):
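        # Compile fn with the given backend and check the result against eager.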
1538        eager_res = fn(*args)
1539        compiled_fn = torch.compile(fn, backend=backend)
1540        self.assertEqual(compiled_fn(*args), eager_res)
1541
1542    def test_cond_traced_not_nested(self):
1543        def true_fn(x):
1544            return x.sin()
1545
1546        def false_fn(x):
1547            return x.cos()
1548
1549        def f(x, y):
1550            return cond(y, true_fn, false_fn, [x])
1551
1552        x = torch.randn(4)
1553        graph = make_fx(f)(x, torch.tensor(False))
1554        result_true = graph.forward(x, torch.tensor(True))
1555        result_false = graph.forward(x, torch.tensor(False))
1556        self.assertFalse(torch.allclose(result_true, result_false))
1557        self.assertEqual(result_true, torch.sin(x))
1558        self.assertEqual(result_false, torch.cos(x))
1559
1560        graph = make_fx(f, tracing_mode="symbolic")(x, torch.tensor(False))
1561        self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))
1562
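    # Minimal eager sketch (illustrative only; not collected as a test): with
    # a boolean tensor predicate, cond evaluates exactly one branch, matching
    # a plain Python conditional on pred.item(). The helper name below is made
    # up for this example.
    def _example_cond_eager_semantics(self):
        x = torch.randn(4)
        pred = torch.tensor(True)
        via_cond = cond(pred, torch.sin, torch.cos, (x,))
        via_if = torch.sin(x) if pred.item() else torch.cos(x)
        assert torch.allclose(via_cond, via_if)
        return via_cond
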
1563    @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
1564    def test_cond_simple_with_linear_compile_check_graph(self):
1565        from torch._dynamo.testing import EagerAndRecordGraphs
1566
1567        def true_fn(x):
1568            return x.sin()
1569
1570        def false_fn(x):
1571            return x.cos()
1572
1573        x = torch.randn(4, requires_grad=True)
1574
1575        def f(pred, x):
1576            result = cond(pred, true_fn, false_fn, (x,))
1577            grad_out = torch.ones_like(result)
1578            return torch.autograd.grad(result, (x,), grad_out)
1579
1580        backend = EagerAndRecordGraphs()
1581        torch.compile(f, backend=backend)(torch.tensor(False), x)
1582        self.assertEqual(len(backend.graphs), 2)
1583        gm = backend.graphs[0]
1584
1585        self.assertExpectedInline(
1586            gm.code.strip(),
1587            """\
1588def forward(self, L_pred_ : torch.Tensor, L_x_ : torch.Tensor):
1589    l_pred_ = L_pred_
1590    l_x_ = L_x_
1591    cond_true_0 = self.cond_true_0
1592    cond_false_0 = self.cond_false_0
1593    cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [l_x_]);  l_pred_ = cond_true_0 = cond_false_0 = l_x_ = None
1594    result = cond[0];  cond = None
1595    grad_out = torch.ones_like(result)
1596    return (result, grad_out)""",  # noqa: B950
1597        )
1598
1599        self.assertExpectedInline(
1600            gm.cond_true_0.code.strip(),
1601            """\
1602def forward(self, l_x_):
1603    l_x__1 = l_x_
1604    sin = l_x__1.sin();  l_x__1 = None
1605    return (sin,)""",  # noqa: B950
1606        )
1607        self.assertExpectedInline(
1608            gm.cond_false_0.code.strip(),
1609            """\
1610def forward(self, l_x_):
1611    l_x__1 = l_x_
1612    cos = l_x__1.cos();  l_x__1 = None
1613    return (cos,)""",  # noqa: B950
1614        )
1615
1616        backward_gm = backend.graphs[1]
1617        self.assertExpectedInline(
1618            backward_gm.code.strip(),
1619            """\
1620def forward(self, L_ctx_saved_tensors_0_ : torch.Tensor, L_ctx_pred : torch.Tensor, L_flat_grads_0_ : torch.Tensor):
1621    l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_
1622    l_ctx_pred = L_ctx_pred
1623    l_flat_grads_0_ = L_flat_grads_0_
1624    cond_true_0 = self.cond_true_0
1625    cond_false_0 = self.cond_false_0
1626    cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, [l_ctx_saved_tensors_0_, l_flat_grads_0_]);  l_ctx_pred = cond_true_0 = cond_false_0 = l_ctx_saved_tensors_0_ = l_flat_grads_0_ = None
1627    getitem = cond[0];  cond = None
1628    return (getitem,)""",  # noqa: B950
1629        )
1630
1631    def test_while_loop_nested_traced(self):
1632        fn, inp = WHILE_LOOP_TESTS["nested"]
1633        graphs = self._check_tracing(fn, inp)
1634        self.assertExpectedInline(
1635            graphs["symbolic"].code.strip("\n"),
1636            """\
1637def forward(self, out_iter_1, it_1, y_1):
1638    while_loop_cond_graph_0 = self.while_loop_cond_graph_0
1639    while_loop_body_graph_0 = self.while_loop_body_graph_0
1640    while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (out_iter_1, it_1, y_1), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = out_iter_1 = it_1 = y_1 = None
1641    getitem = while_loop[0]
1642    getitem_1 = while_loop[1]
1643    getitem_2 = while_loop[2];  while_loop = None
1644    return (getitem, getitem_1, getitem_2)
1645    """,  # noqa: B950
1646        )
1647        self.assertExpectedInline(
1648            graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
1649            """\
1650def forward(self, arg0_1, arg1_1, arg2_1):
1651    sum_1 = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
1652    lt = torch.ops.aten.lt.Scalar(sum_1, 2);  sum_1 = None
1653    return lt
1654    """,
1655        )
1656        self.assertExpectedInline(
1657            graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
1658            """\
1659def forward(self, arg0_1, arg1_1, arg2_1):
1660    while_loop_cond_graph_0 = self.while_loop_cond_graph_0
1661    while_loop_body_graph_0 = self.while_loop_body_graph_0
1662    while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = None
1663    getitem = while_loop[0]
1664    getitem_1 = while_loop[1]
1665    getitem_2 = while_loop[2];  while_loop = None
1666    add = torch.ops.aten.add.Tensor(getitem, 1);  getitem = None
1667    return (add, getitem_1, getitem_2)
1668    """,  # noqa: B950
1669        )
1670
1671    def _wrap_with_functionalize(self, fn, func_type):
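        # Wrap fn with the requested functionalization entry point (C++,
        # Python, or functorch). The "python" variant additionally needs to
        # run under FunctionalTensorMode, which is returned as `mode`.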
1672        mode = None
1673        if func_type == "cpp":
1674            fn = CppFunctionalizeAPI().functionalize(fn)
1675        elif func_type == "python":
1676            fn = PythonFunctionalizeAPI().functionalize(fn)
1677            mode = FunctionalTensorMode()
1678        elif func_type == "functorch":
1679            fn = torch.func.functionalize(fn)
1680        else:
1681            assert func_type == "no"
1682        return fn, mode
1683
1684    @parametrize("func_type", ["no", "cpp", "python", "functorch"])
1685    def test_while_loop_simple_functionalize_check_graph(self, func_type):
1686        fn, inp = WHILE_LOOP_TESTS["simple_with_mutation"]
1687        fn, mode = self._wrap_with_functionalize(fn, func_type)
1688        mode = mode if mode is not None else contextlib.nullcontext()
1689        with mode:
1690            graphs = self._check_tracing(fn, inp)
1691        if func_type == "no":
1692            self.assertExpectedInline(
1693                graphs["symbolic"].code.strip("\n"),
1694                """\
1695def forward(self, x_1):
1696    while_loop_cond_graph_0 = self.while_loop_cond_graph_0
1697    while_loop_body_graph_0 = self.while_loop_body_graph_0
1698    while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None
1699    getitem = while_loop[0];  while_loop = None
1700    return (getitem,)
1701    """,  # noqa: B950
1702            )
1703            self.assertExpectedInline(
1704                graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
1705                """\
1706def forward(self, arg0_1):
1707    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
1708    add_ = torch.ops.aten.add_.Tensor(clone, 1);  clone = None
1709    add__1 = torch.ops.aten.add_.Tensor(add_, -1);  add_ = None
1710    sum_1 = torch.ops.aten.sum.default(add__1);  add__1 = None
1711    lt = torch.ops.aten.lt.Scalar(sum_1, 10);  sum_1 = None
1712    return lt
1713    """,
1714            )
1715            self.assertExpectedInline(
1716                graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
1717                """\
1718def forward(self, arg0_1):
1719    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
1720    add_ = torch.ops.aten.add_.Tensor(clone, 1);  clone = None
1721    add__1 = torch.ops.aten.add_.Tensor(add_, -1);  add_ = None
1722    add = torch.ops.aten.add.Tensor(add__1, 1);  add__1 = None
1723    return (add,)
1724    """,
1725            )
1726        elif func_type == "python":
1727            self.assertExpectedInline(
1728                graphs["symbolic"].code.strip("\n"),
1729                """\
1730def forward(self, arg0_1):
1731    while_loop_cond_graph_0 = self.while_loop_cond_graph_0
1732    while_loop_body_graph_0 = self.while_loop_body_graph_0
1733    while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1,), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = None
1734    getitem = while_loop[0];  while_loop = None
1735    return (getitem,)
1736    """,  # noqa: B950
1737            )
1738            self.assertExpectedInline(
1739                graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
1740                """\
1741def forward(self, arg0_1):
1742    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
1743    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
1744    add_1 = torch.ops.aten.add.Tensor(add, -1);  add = None
1745    sum_1 = torch.ops.aten.sum.default(add_1);  add_1 = None
1746    lt = torch.ops.aten.lt.Scalar(sum_1, 10);  sum_1 = None
1747    return lt
1748    """,
1749            )
1750            self.assertExpectedInline(
1751                graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
1752                """\
1753def forward(self, arg0_1):
1754    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
1755    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
1756    add_1 = torch.ops.aten.add.Tensor(add, -1);  add = None
1757    add_2 = torch.ops.aten.add.Tensor(add_1, 1);  add_1 = None
1758    return (add_2,)
1759    """,
1760            )
1761        else:
1762            self.assertExpectedInline(
1763                graphs["symbolic"].code.strip("\n"),
1764                """\
1765def forward(self, x_1):
1766    while_loop_cond_graph_0 = self.while_loop_cond_graph_0
1767    while_loop_body_graph_0 = self.while_loop_body_graph_0
1768    while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None
1769    getitem = while_loop[0];  while_loop = None
1770    return (getitem,)
1771    """,  # noqa: B950
1772            )
1773            self.assertExpectedInline(
1774                graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"),
1775                """\
1776def forward(self, arg0_1):
1777    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
1778    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
1779    add_1 = torch.ops.aten.add.Tensor(add, -1);  add = None
1780    sum_1 = torch.ops.aten.sum.default(add_1);  add_1 = None
1781    lt = torch.ops.aten.lt.Scalar(sum_1, 10);  sum_1 = None
1782    return lt
1783    """,
1784            )
1785            self.assertExpectedInline(
1786                graphs["symbolic"].while_loop_body_graph_0.code.strip("\n"),
1787                """\
1788def forward(self, arg0_1):
1789    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
1790    add = torch.ops.aten.add.Tensor(clone, 1);  clone = None
1791    add_1 = torch.ops.aten.add.Tensor(add, -1);  add = None
1792    add_2 = torch.ops.aten.add.Tensor(add_1, 1);  add_1 = None
1793    return (add_2,)
1794    """,
1795            )
1796
1797    @parametrize("func_type", ["no", "cpp", "python", "functorch"])
1798    @parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys()))
1799    def test_while_loop_functionalize(self, func_type, while_loop_test):
1800        # simple_with_linear doesn't work because parameters and buffers
1801        # are not inputs, so they're not wrapped by functionalization and tracing.
1802        if while_loop_test not in ("simple_with_linear", "nested_with_linear"):
1803            fn, inp = WHILE_LOOP_TESTS[while_loop_test]
1804            fn, mode = self._wrap_with_functionalize(fn, func_type)
1805            mode = mode if mode is not None else contextlib.nullcontext()
1806            with mode:
1807                self._check_tracing(fn, inp)
1808
1809    @parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys()))
1810    def test_while_loop_tracing(self, while_loop_test):
1811        fn, inp = WHILE_LOOP_TESTS[while_loop_test]
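        # The *_with_linear variants close over real module parameters and
        # buffers (they are not graph inputs), so tracing them requires
        # allow_non_fake_inputs=True.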
1812        allow_non_fake_inputs = while_loop_test in (
1813            "simple_with_linear",
1814            "nested_with_linear",
1815        )
1816
1817        self._check_tracing(fn, inp, allow_non_fake_inputs)
1818
1819    @parametrize("backend", ["eager", "aot_eager"])
1820    @parametrize("while_loop_test", list(WHILE_LOOP_TESTS.keys()))
1821    def test_while_loop_compile(self, backend, while_loop_test):
1822        fn, inp = WHILE_LOOP_TESTS[while_loop_test]
1823        self._check_compile(fn, inp, backend=backend)
1824
1825    @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
1826    def test_while_loop_simple_with_linear_compile_check_graph(self):
1827        fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
1828        from torch._dynamo.testing import EagerAndRecordGraphs
1829
1830        backend = EagerAndRecordGraphs()
1831        torch.compile(fn, backend=backend)(*inp)
1832        self.assertEqual(len(backend.graphs), 1)
1833        gm = backend.graphs[0]
1834        if torch._dynamo.config.inline_inbuilt_nn_modules:
1835            self.assertExpectedInline(
1836                gm.code.strip(),
1837                """\
1838def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor, L_self_buffers_dec_ : torch.Tensor, L_self_modules_linear_parameters_weight_ : torch.nn.parameter.Parameter, L_self_modules_linear_parameters_bias_ : torch.nn.parameter.Parameter):
1839    l_iter_ = L_iter_
1840    l_x_ = L_x_
1841    l_self_buffers_dec_ = L_self_buffers_dec_
1842    l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_
1843    l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_
1844    cond_fn_0 = self.cond_fn_0
1845    body_fn_0 = self.body_fn_0
1846    while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l_self_buffers_dec_, l_self_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_));  cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l_self_buffers_dec_ = l_self_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = None
1847    getitem = while_loop[0]
1848    getitem_1 = while_loop[1];  while_loop = None
1849    return (getitem, getitem_1)""",  # noqa: B950
1850            )
1851            self.assertExpectedInline(
1852                gm.cond_fn_0.code.strip(),
1853                """\
1854def forward(self, l_iter_, l_x_, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
1855    sub = l_iter_ - l_self_buffers_dec__cond_fn;  l_iter_ = l_self_buffers_dec__cond_fn = None
1856    gt = sub > 0;  sub = None
1857    return gt""",  # noqa: B950
1858            )
1859            self.assertExpectedInline(
1860                gm.body_fn_0.code.strip(),
1861                """\
1862def forward(self, l_iter_, l_x_, l_self_buffers_dec__cond_fn, l_self_modules_linear_parameters_bias__body_fn, l_self_modules_linear_parameters_weight__body_fn):
1863    child = l_iter_ - 1;  l_iter_ = None
1864    child_1 = torch._C._nn.linear(l_x_, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn);  l_x_ = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None
1865    return (child, child_1)""",  # noqa: B950
1866            )
1867        else:
1868            self.assertExpectedInline(
1869                gm.code.strip(),
1870                """\
1871def forward(self, L_iter_ : torch.Tensor, L_x_ : torch.Tensor):
1872    l_iter_ = L_iter_
1873    l_x_ = L_x_
1874    l__self___dec = self.L__self___dec
1875    l__self___linear_weight = self.L__self___linear_weight
1876    l__self___linear_bias = self.L__self___linear_bias
1877    cond_fn_0 = self.cond_fn_0
1878    body_fn_0 = self.body_fn_0
1879    while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_iter_, l_x_), (l__self___dec, l__self___linear_bias, l__self___linear_weight));  cond_fn_0 = body_fn_0 = l_iter_ = l_x_ = l__self___dec = l__self___linear_bias = l__self___linear_weight = None
1880    getitem = while_loop[0]
1881    getitem_1 = while_loop[1];  while_loop = None
1882    return (getitem, getitem_1)""",  # noqa: B950
1883            )
1884            self.assertExpectedInline(
1885                gm.cond_fn_0.code.strip(),
1886                """\
1887def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
1888    sub = l_iter_ - l__self___dec_cond_fn;  l_iter_ = l__self___dec_cond_fn = None
1889    gt = sub > 0;  sub = None
1890    return gt""",  # noqa: B950
1891            )
1892            self.assertExpectedInline(
1893                gm.body_fn_0.code.strip(),
1894                """\
1895def forward(self, l_iter_, l_x_, l__self___dec_cond_fn, l__self___linear_bias_body_fn, l__self___linear_weight_body_fn):
1896    child = l_iter_ - 1;  l_iter_ = None
1897    child_1 = torch._C._nn.linear(l_x_, l__self___linear_weight_body_fn, l__self___linear_bias_body_fn);  l_x_ = l__self___linear_weight_body_fn = l__self___linear_bias_body_fn = None
1898    return (child, child_1)""",  # noqa: B950
1899            )
1900
1901    def test_while_loop_nested2_traced(self):
1902        fn, inp = WHILE_LOOP_TESTS["nested2"]
1903        graphs = self._check_tracing(fn, inp)
1904        gm = graphs["symbolic"]
1905        outer_body = gm.while_loop_body_graph_0
1906        outer_cond = gm.while_loop_cond_graph_0
1907        inner_body = outer_body.while_loop_body_graph_0
1908        inner_cond = outer_body.while_loop_cond_graph_0
1909        self.assertExpectedInline(
1910            gm.code.strip("\n"),
1911            """\
1912def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
1913    while_loop_cond_graph_0 = self.while_loop_cond_graph_0
1914    while_loop_body_graph_0 = self.while_loop_body_graph_0
1915    while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
1916    getitem = while_loop[0]
1917    getitem_1 = while_loop[1]
1918    getitem_2 = while_loop[2]
1919    getitem_3 = while_loop[3];  while_loop = None
1920    return (getitem, getitem_1, getitem_2, getitem_3)
1921    """,  # noqa: B950
1922        )
1923        self.assertExpectedInline(
1924            outer_body.code.strip("\n"),
1925            """\
1926def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
1927    while_loop_cond_graph_0 = self.while_loop_cond_graph_0
1928    while_loop_body_graph_0 = self.while_loop_body_graph_0
1929    while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
1930    getitem = while_loop[0]
1931    getitem_1 = while_loop[1]
1932    getitem_2 = while_loop[2]
1933    getitem_3 = while_loop[3];  while_loop = None
1934    sub = torch.ops.aten.sub.Tensor(getitem, 1);  getitem = None
1935    clone = torch.ops.aten.clone.default(getitem_1);  getitem_1 = None
1936    mul = torch.ops.aten.mul.Tensor(getitem_2, 2);  getitem_2 = None
1937    div = torch.ops.aten.div.Tensor(getitem_3, 2);  getitem_3 = None
1938    return (sub, clone, mul, div)
1939    """,  # noqa: B950
1940        )
1941        self.assertExpectedInline(
1942            outer_body.code.strip("\n"),
1943            """\
1944def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
1945    while_loop_cond_graph_0 = self.while_loop_cond_graph_0
1946    while_loop_body_graph_0 = self.while_loop_body_graph_0
1947    while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1, arg1_1, arg2_1, arg3_1), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
1948    getitem = while_loop[0]
1949    getitem_1 = while_loop[1]
1950    getitem_2 = while_loop[2]
1951    getitem_3 = while_loop[3];  while_loop = None
1952    sub = torch.ops.aten.sub.Tensor(getitem, 1);  getitem = None
1953    clone = torch.ops.aten.clone.default(getitem_1);  getitem_1 = None
1954    mul = torch.ops.aten.mul.Tensor(getitem_2, 2);  getitem_2 = None
1955    div = torch.ops.aten.div.Tensor(getitem_3, 2);  getitem_3 = None
1956    return (sub, clone, mul, div)
1957    """,  # noqa: B950
1958        )
1959        self.assertExpectedInline(
1960            inner_body.code.strip("\n"),
1961            """\
1962def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
1963    clone = torch.ops.aten.clone.default(arg0_1);  arg0_1 = None
1964    sub = torch.ops.aten.sub.Tensor(arg1_1, 1);  arg1_1 = None
1965    add = torch.ops.aten.add.Tensor(arg2_1, 3.14);  arg2_1 = None
1966    sub_1 = torch.ops.aten.sub.Tensor(arg3_1, 2.71);  arg3_1 = None
1967    return (clone, sub, add, sub_1)
1968    """,
1969        )
1970        self.assertExpectedInline(
1971            inner_cond.code.strip("\n"),
1972            """\
1973def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
1974    gt = torch.ops.aten.gt.Scalar(arg1_1, 0);  arg1_1 = None
1975    return gt
1976    """,
1977        )
1978
1979    def test_cond_nested_traced(self):
1980        def true_nested(y):
1981            return y * y
1982
1983        def false_nested(y):
1984            return y + y
1985
1986        def true_fn(x, pred2):
1987            z = cond(pred2, true_nested, false_nested, [x])
1988            return x + z
1989
1990        def false_fn(x, _):
1991            return x.cos()
1992
1993        def f(x, pred, pred2):
1994            return cond(pred, true_fn, false_fn, [x, pred2])
1995
1996        x = torch.randn(4)
1997        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
1998
1999        result_true_true = graph.forward(
2000            x, torch.tensor(True), torch.tensor(True)
2001        )  # True + True -> x * x
2002        result_true_false = graph.forward(
2003            x, torch.tensor(True), torch.tensor(False)
2004        )  # True + False -> x + x
2005        result_false_true = graph.forward(
2006            x, torch.tensor(False), torch.tensor(True)
2007        )  # False + either -> cos
2008        result_false_false = graph.forward(
2009            x, torch.tensor(False), torch.tensor(False)
2010        )  # False + either -> cos
2011
2012        self.assertNotEqual(result_true_true, result_true_false)
2013        self.assertFalse(torch.allclose(result_false_true, result_true_true))
2014
2015        self.assertEqual(result_false_true, result_false_false)
2016
2017        self.assertEqual(result_true_true, (x * x) + x)
2018        self.assertEqual(result_true_false, x + x + x)
2019
2020        self.assertEqual(result_false_true, torch.cos(x))
2021
2022        graph = make_fx(f, tracing_mode="symbolic")(
2023            x, torch.tensor(False), torch.tensor(False)
2024        )
2025        self.assertEqual(
2026            graph(x, torch.tensor(True), torch.tensor(True)),
2027            f(x, torch.tensor(True), torch.tensor(True)),
2028        )
2029
2030    def test_cond_functionalized(self):
2031        def true_fn(x):
2032            y = x.sin()
2033            y.add_(4)
2034            return x.sin().max() + y.sum()
2035
2036        def false_fn(x):
2037            return x.cos().min()
2038
2039        def f(x):
2040            pred = x.shape[0] == 1
2041            return cond(pred, true_fn, false_fn, [x])
2042
2043        example_inputs = (torch.ones(4, 5),)
2044        functional_f = torch.func.functionalize(f)
2045        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
2046
2047        graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
2048            *example_inputs
2049        )
2050        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
2051
2052        all_ops_in_true_branch = []
2053        for node in graph_module.true_graph_0.graph.nodes:
2054            if node.op == "call_function":
2055                all_ops_in_true_branch.append(node.target)
2056
2057        self.assertFalse(any(op._schema.is_mutable for op in all_ops_in_true_branch))
2058
2059        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
2060
2061    def test_cond_accepts_torch_function_as_inputs(self):
2062        a = torch.randn(3, 4)
2063        b = torch.randn(3, 4)
2064
2065        def f(a, b):
2066            return cond(a.sum() > 0, torch.add, torch.mul, (a, b))
2067
2068        gm = self._check_tracing(f, (a, b))["symbolic"]
2069        self.assertExpectedInline(
2070            gm.code.strip(),
2071            """\
2072def forward(self, a_1, b_1):
2073    sum_1 = torch.ops.aten.sum.default(a_1)
2074    gt = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
2075    true_graph_0 = self.true_graph_0
2076    false_graph_0 = self.false_graph_0
2077    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1]);  gt = true_graph_0 = false_graph_0 = a_1 = b_1 = None
2078    getitem = cond[0];  cond = None
2079    return getitem""",  # noqa: B950
2080        )
2081        self.assertExpectedInline(
2082            gm.true_graph_0.code.strip(),
2083            """\
2084def forward(self, arg0_1, arg1_1):
2085    add = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
2086    return (add,)""",
2087        )
2088        self.assertExpectedInline(
2089            gm.false_graph_0.code.strip(),
2090            """\
2091def forward(self, arg0_1, arg1_1):
2092    mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
2093    return (mul,)""",
2094        )
2095
2096    def test_cond_retrace_functionalized(self):
2097        def true_fn(x):
2098            return x.sin()
2099
2100        def false_fn(x):
2101            return x.cos()
2102
2103        def f(x):
2104            return cond(x.all(), true_fn, false_fn, (x,))
2105
2106        inp = torch.ones(1, 2)
2107        gm_non_functional = make_fx(f, tracing_mode="real")(inp)
2108        gm_functional = make_fx(
2109            torch.func.functionalize(gm_non_functional), tracing_mode="real"
2110        )(inp)
2111        self.assertEqual(gm_functional(torch.zeros(1, 2)), f(torch.zeros(1, 2)))
2112
2113    def test_cond_subgraph_same_shape_env_as_parent(self):
2114        def true_fn(x):
2115            return x.sin() + 10
2116
2117        def false_fn(x):
2118            return x.cos() - 20
2119
2120        def f(x, pred):
2121            y = cond(pred, true_fn, false_fn, [x])
2122            z = torch.add(y, y)
2123            return z
2124
2125        symbolic_traced_graph = self._check_tracing(
2126            f, (torch.ones(4), torch.Tensor([True]))
2127        )["symbolic"]
2128        graph_shape_env = symbolic_traced_graph.shape_env
2129
2130        def _node_shape_env_iter(gm):
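            # Yield the ShapeEnv of the fake value recorded on each
            # call_function node.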
2131            for node in gm.graph.nodes:
2132                if node.op == "call_function":
2133                    val = node.meta.get("val")
2134                    if isinstance(val, tuple):
2135                        for v in val:
2136                            yield v.fake_mode.shape_env
2137                    else:
2138                        yield val.fake_mode.shape_env
2139
2140        for shape_env in _node_shape_env_iter(symbolic_traced_graph):
2141            self.assertTrue(shape_env is graph_shape_env)
2142
2143        for shape_env in _node_shape_env_iter(symbolic_traced_graph.true_graph_0):
2144            self.assertTrue(shape_env is graph_shape_env)
2145
2146        for shape_env in _node_shape_env_iter(symbolic_traced_graph.false_graph_0):
2147            self.assertTrue(shape_env is graph_shape_env)
2148
2149    def test_cond_functionalized_nested(self):
2150        def true_true_fn(x):
2151            y = x.cos()
2152            y.add_(4)
2153            return x.sin().max() + y.sin().max()
2154
2155        def true_false_fn(x):
2156            return x.cos().min()
2157
2158        def true_fn(x):
2159            pred = x.shape[0] == 1
2160            return cond(pred, true_true_fn, true_false_fn, [x])
2161
2162        def false_fn(x):
2163            return x.sum()
2164
2165        def f(x):
2166            pred = x.shape[0] == 1
2167            return cond(pred, true_fn, false_fn, [x])
2168
2169        example_inputs = (torch.ones(4, 5),)
2170        functional_f = torch.func.functionalize(f)
2171        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
2172
2173        graph_module = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
2174            *example_inputs
2175        )
2176        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
2177
2178        gm_true_true_branch = graph_module.true_graph_0.true_graph_0
2179
2180        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
2181
2182        all_ops = []
2183        for node in gm_true_true_branch.graph.nodes:
2184            if node.op == "call_function":
2185                all_ops.append(node.target)
2186
2187        self.assertFalse(any(op._schema.is_mutable for op in all_ops))
2188
2189    def test_cond_functionalized_data_dependent_pred(self):
2190        def true_fn(x):
2191            return x.sin().sum()
2192
2193        def false_fn(x):
2194            return x.cos().sum()
2195
2196        def f(x):
2197            pred = x.nonzero().shape[0] == 1
2198            return cond(pred, true_fn, false_fn, [x])
2199
2200        example_inputs = (torch.ones(4, 5),)
2201        functional_f = torch.func.functionalize(f)
2202        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
2203
2204        graph_module = make_fx(torch.func.functionalize(f))(*example_inputs)
2205        self.assertEqual(graph_module(*example_inputs), f(*example_inputs))
2206
2207    # https://github.com/pytorch/pytorch/issues/126988
2208    def test_cond_functionalized_input_mutation_on_true_branch(self):
2209        def true_fn(x):
2210            view_x = x.view(x.shape)
2211            view_x.add_(1)
2212            return view_x.sin().sum()
2213
2214        def false_fn(x):
2215            return x.cos().sum()
2216
2217        def f(x):
2218            pred = x.shape[0] == 4
2219            return cond(pred, true_fn, false_fn, [x])
2220
2221        example_inputs = (torch.ones(4, 5),)
2222        # torch.cond inlines into one of the branches because the predicate
2223        # is a constant.
2224        gm = make_fx(torch.func.functionalize(f))(*example_inputs)
2225        self.assertExpectedInline(
2226            gm.code.strip(),
2227            """\
2228def forward(self, x_1):
2229    view = torch.ops.aten.view.default(x_1, [4, 5])
2230    add = torch.ops.aten.add.Tensor(view, 1);  view = None
2231    view_1 = torch.ops.aten.view.default(add, [4, 5]);  add = None
2232    view_2 = torch.ops.aten.view.default(view_1, [4, 5])
2233    sin = torch.ops.aten.sin.default(view_2);  view_2 = None
2234    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
2235    copy_ = torch.ops.aten.copy_.default(x_1, view_1);  x_1 = view_1 = copy_ = None
2236    return sum_1""",
2237        )
2238
2239        # torch.cond checks the branches for input mutation/aliasing because
2240        # the predicate is a SymBool.
2241        with self.assertRaisesRegex(
2242            UnsupportedAliasMutationException, "One of torch.cond branch"
2243        ):
2244            make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
2245                *example_inputs
2246            )
2247
2248    # https://github.com/pytorch/pytorch/issues/126988
2249    def test_cond_functionalized_input_mutation_on_false_branch(self):
2250        def true_fn(x):
2251            return x.sin().sum()
2252
2253        def false_fn(x):
2254            view_x = x.view(x.shape)
2255            view_x.add_(1)
2256            return view_x.cos().sum()
2257
2258        def f(x):
2259            pred = x.shape[0] == 4
2260            return cond(pred, true_fn, false_fn, [x])
2261
2262        example_inputs = (torch.ones(5, 5),)
2263        gm = make_fx(torch.func.functionalize(f))(*example_inputs)
2264        # torch.cond inlines into one of the branches because the predicate
2265        # is a constant.
2266        self.assertExpectedInline(
2267            gm.code.strip(),
2268            """\
2269def forward(self, x_1):
2270    view = torch.ops.aten.view.default(x_1, [5, 5])
2271    add = torch.ops.aten.add.Tensor(view, 1);  view = None
2272    view_1 = torch.ops.aten.view.default(add, [5, 5]);  add = None
2273    view_2 = torch.ops.aten.view.default(view_1, [5, 5])
2274    cos = torch.ops.aten.cos.default(view_2);  view_2 = None
2275    sum_1 = torch.ops.aten.sum.default(cos);  cos = None
2276    copy_ = torch.ops.aten.copy_.default(x_1, view_1);  x_1 = view_1 = copy_ = None
2277    return sum_1""",
2278        )
2279
2280        # torch.cond checks the branches for input mutation/aliasing because
2281        # the predicate is a SymBool.
2282        with self.assertRaisesRegex(
2283            UnsupportedAliasMutationException, "One of torch.cond branch"
2284        ):
2285            make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
2286                *example_inputs
2287            )
2288
2289    # https://github.com/pytorch/pytorch/issues/126988
2290    def test_cond_functionalized_output_alias_input(self):
2291        def true_fn(x):
2292            return x
2293
2294        def false_fn(x):
2295            view_x = x.view(x.shape)
2296            return view_x
2297
2298        def f(x):
2299            pred = x.shape[0] == 4
2300            return cond(pred, true_fn, false_fn, [x])
2301
2302        example_inputs = (torch.ones(5, 5),)
2303        gm = make_fx(torch.func.functionalize(f))(*example_inputs)
2304        # torch.cond inlines into one of the branches because the predicate
2305        # is a constant.
2306        self.assertExpectedInline(
2307            gm.code.strip(),
2308            """\
2309def forward(self, x_1):
2310    view = torch.ops.aten.view.default(x_1, [5, 5]);  x_1 = None
2311    return view""",
2312        )
2313
2314        # torch.cond checks the branches for input mutation/aliasing because
2315        # the predicate is a SymBool.
2316        with self.assertRaisesRegex(
2317            UnsupportedAliasMutationException, "One of torch.cond branch"
2318        ):
2319            make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
2320                *example_inputs
2321            )
2322
2323    # https://github.com/pytorch/pytorch/issues/126988
2324    def test_cond_functionalized_nested_input_mutation(self):
2325        def true_true_fn(x):
2326            x.add_(4)
2327            return x.sin().max()
2328
2329        def true_false_fn(x):
2330            return x.cos().min()
2331
2332        def true_fn(x):
2333            pred = x.shape[0] == 1
2334            return cond(pred, true_true_fn, true_false_fn, [x])
2335
2336        def false_fn(x):
2337            return x.sum()
2338
2339        def f(x):
2340            pred = x.shape[0] == 1
2341            return cond(pred, true_fn, false_fn, [x])
2342
2343        example_inputs = (torch.ones(4, 5),)
2344        with self.assertRaisesRegex(
2345            UnsupportedAliasMutationException, "One of torch.cond branch"
2346        ):
2347            make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
2348                *example_inputs
2349            )
2350
2351    # https://github.com/pytorch/pytorch/issues/126988
2352    def test_cond_functionalized_nested_input_mutation_with_aot_func(self):
2353        def true_true_fn(x):
2354            x.add_(4)
2355            return x.sin().max()
2356
2357        def true_false_fn(x):
2358            return x.cos().min()
2359
2360        def true_fn(x):
2361            pred = x.shape[0] == 1
2362            return cond(pred, true_true_fn, true_false_fn, [x])
2363
2364        def false_fn(x):
2365            return x.sum()
2366
2367        def f(x):
2368            pred = x.shape[0] == 1
2369            return cond(pred, true_fn, false_fn, [x])
2370
2371        example_input = torch.ones(4, 5)
2372        try:
2373            example_input_func = to_fun_old(example_input)
2374            torch._enable_functionalization(reapply_views=False)
2375            f(example_input_func)
2376
2377            with self.assertRaisesRegex(
2378                UnsupportedAliasMutationException, "One of torch.cond branch"
2379            ):
2380                make_fx(f, tracing_mode="symbolic")(example_input_func)
2381        finally:
2382            torch._disable_functionalization()
2383
2384        def f_wrapper(func):
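            # Run func with functionalization enabled for the duration of the
            # call.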
2385            @functools.wraps(func)
2386            def wrapper(*args, **kwargs):
2387                torch._enable_functionalization(reapply_views=False)
2388                try:
2389                    return func(*args, **kwargs)
2390                finally:
2391                    torch._disable_functionalization()
2392
2393            return wrapper
2394
2395        with self.assertRaisesRegex(
2396            UnsupportedAliasMutationException, "One of torch.cond branch"
2397        ):
2398            make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input_func)
2399
2400    # https://github.com/pytorch/pytorch/issues/126988
2401    @xfailIfTorchDynamo
2402    def test_cond_functionalized_input_aliasing_with_aot_func(self):
2403        def true_fn(x):
2404            return x
2405
2406        def false_fn(x):
2407            view_x = x.view(x.shape)
2408            return view_x
2409
2410        def f(x):
2411            pred = x.sum() > 0
2412            return cond(pred, true_fn, false_fn, [x])
2413
2414        example_input = torch.ones(5, 5)
2415        try:
2416            example_input_func = to_fun_old(example_input)
2417            torch._enable_functionalization(reapply_views=False)
2418            with self.assertRaisesRegex(
2419                UnsupportedAliasMutationException,
2420                "One of torch.cond branch might be aliasing",
2421            ):
2422                f(example_input_func)
2423        finally:
2424            torch._disable_functionalization()
2425
2426        def f_wrapper(func):
2427            @functools.wraps(func)
2428            def wrapper(*args, **kwargs):
2429                torch._enable_functionalization(reapply_views=False)
2430                try:
2431                    func_args = pytree.tree_map(
2432                        lambda x: torch._to_functional_tensor(x)
2433                        if isinstance(x, torch.Tensor)
2434                        else x,
2435                        args,
2436                    )
2437                    func_kwargs = pytree.tree_map(
2438                        lambda x: torch._to_functional_tensor(x)
2439                        if isinstance(x, torch.Tensor)
2440                        else x,
2441                        kwargs,
2442                    )
2443                    return func(*func_args, **func_kwargs)
2444                finally:
2445                    torch._disable_functionalization()
2446
2447            return wrapper
2448
2449        with self.assertRaisesRegex(
2450            UnsupportedAliasMutationException,
2451            "One of torch.cond branch might be aliasing",
2452        ):
2453            make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
2454
2455    def test_cond_functionalized_aot_func_check_functional(self):
2456        def true_fn(x):
2457            return x.cos()
2458
2459        def false_fn(x):
2460            y = x.sin()
2461            y.add_(5)
2462            return y
2463
2464        def f(x):
2465            pred = x.shape[0] == 4
2466            return cond(pred, true_fn, false_fn, [x])
2467
2468        example_input = torch.ones(5, 5)
2469
2470        def f_wrapper(func):
2471            @functools.wraps(func)
2472            def wrapper(*args, **kwargs):
2473                torch._enable_functionalization(reapply_views=False)
2474                try:
2475                    func_args = pytree.tree_map(
2476                        lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
2477                        args,
2478                    )
2479                    func_kwargs = pytree.tree_map(
2480                        lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
2481                        kwargs,
2482                    )
2483                    return pytree.tree_map(
2484                        from_fun_old, func(*func_args, **func_kwargs)
2485                    )
2486                finally:
2487                    torch._disable_functionalization()
2488
2489            return wrapper
2490
2491        result_gm = make_fx(f_wrapper(f), tracing_mode="symbolic")(example_input)
2492        for node in result_gm.true_graph_0.graph.nodes:
2493            if node.op == "call_function":
2494                self.assertTrue(not node.target._schema.is_mutable)
2495
2496        for node in result_gm.false_graph_0.graph.nodes:
2497            if node.op == "call_function":
2498                self.assertTrue(not node.target._schema.is_mutable)
2499
2500        self.assertEqual(result_gm(torch.ones(5, 5)), f(torch.ones(5, 5)))
2501
2502    def test_cond_nested_traced_other_inputs(self):
2503        def true_nested(y):
2504            return y * y
2505
2506        def false_nested(y):
2507            return y + y
2508
2509        def true_fn(k, pred2):
2510            z = cond(pred2, true_nested, false_nested, [k])
2511            return torch.add(torch.tensor([0.25, 0.25]), z)
2512
2513        def false_fn(k, _):
2514            return k.cos()
2515
2516        def f(k, pred, pred2):
2517            return cond(pred, true_fn, false_fn, [k, pred2])
2518
2519        x = torch.tensor([0.5, 0.5])
2520        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
2521
2522        a = torch.tensor([1.0, 1.0])
2523        result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
2524        self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))
2525
2526        b = torch.tensor([2.0, 2.0])
2527        result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
2528        self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))
2529
2530    def test_cond_nested_traced_multi(self):
2531        def true_a(y):
2532            return y * y
2533
2534        def false_a(y):
2535            return y + y
2536
2537        def true_b(y, z):
2538            return y + z
2539
2540        def false_b(y, z):
2541            return y * z
2542
2543        def f(x, pred, pred2):
2544            a_out = cond(pred, true_a, false_a, [x])
2545            b_out = cond(pred2, true_b, false_b, [x, x])
2546            return a_out + b_out
2547
2548        x = torch.randn(4)
2549        graph = make_fx(f)(x, torch.tensor(False), torch.tensor(False))
2550
2551        self.assertExpectedInline(
2552            graph.code.strip(),
2553            """\
2554def forward(self, x_1, pred_1, pred2_1):
2555    true_graph_0 = self.true_graph_0
2556    false_graph_0 = self.false_graph_0
2557    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]);  pred_1 = true_graph_0 = false_graph_0 = None
2558    getitem = cond[0];  cond = None
2559    true_graph_1 = self.true_graph_1
2560    false_graph_1 = self.false_graph_1
2561    cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]);  pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
2562    getitem_1 = cond_1[0];  cond_1 = None
2563    add = torch.ops.aten.add.Tensor(getitem, getitem_1);  getitem = getitem_1 = None
2564    return add""",  # noqa: B950
2565        )
2566        self.assertExpectedInline(
2567            graph.true_graph_0.code.strip(),
2568            """\
2569def forward(self, arg0_1):
2570    mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
2571    return (mul,)""",
2572        )
2573
2574    def test_raise_error_on_mismatch_type_size(self):
2575        def true_fn(x):
2576            return x.sin()
2577
2578        def false_fn(x):
2579            return (x, x)
2580
2581        def f(x, y):
2582            return cond(y, true_fn, false_fn, [x])
2583
2584        x = torch.randn(4)
2585        with self.assertRaisesRegex(
2586            torch._dynamo.exc.CondOpArgsMismatchError,
2587            "Expected to return same number of outputs but got:",
2588        ):
2589            make_fx(f)(x, torch.tensor(False))
2590
2591    def test_raise_error_on_mismatch_tensor_size(self):
2592        def true_fn(x):
2593            return x.sin()
2594
2595        def false_fn(x):
2596            return torch.zeros([10, 10])
2597
2598        def f(x, y):
2599            return cond(y, true_fn, false_fn, [x])
2600
2601        x = torch.randn(4)
2602        with self.assertRaisesRegex(
2603            torch._dynamo.exc.UncapturedHigherOrderOpError,
2604            "Cond doesn't work unless it is captured completely with torch.compile",
2605        ):
2606            make_fx(f)(x, torch.tensor(False))
2607
2608    def test_cond_traced_not_nested_fake_tensor(self):
2609        def true_fn(x):
2610            return x.sin()
2611
2612        def false_fn(x):
2613            return x.cos()
2614
2615        def f(x, y):
2616            return cond(y, true_fn, false_fn, [x])
2617
2618        x = torch.randn(4)
2619        graph = make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
2620        result_true = graph.forward(x, torch.tensor(True))
2621        result_false = graph.forward(x, torch.tensor(False))
2622        self.assertFalse(torch.allclose(result_true, result_false))
2623        self.assertEqual(result_true, torch.sin(x))
2624        self.assertEqual(result_false, torch.cos(x))
2625
2626    def test_cond_nested_traced_fake_tensor(self):
2627        def true_nested(y):
2628            return y * y
2629
2630        def false_nested(y):
2631            return y + y
2632
2633        def true_fn(x, pred2):
2634            z = cond(pred2, true_nested, false_nested, [x])
2635            return x + z
2636
2637        def false_fn(x, _):
2638            return x.cos()
2639
2640        def f(x, pred, pred2):
2641            return cond(pred, true_fn, false_fn, [x, pred2])
2642
2643        x = torch.randn(4)
2644        graph = make_fx(f, tracing_mode="fake")(
2645            x, torch.tensor(False), torch.tensor(False)
2646        )
2647
2648        result_true_true = graph.forward(
2649            x, torch.tensor(True), torch.tensor(True)
2650        )  # True + True -> x * x
2651        result_true_false = graph.forward(
2652            x, torch.tensor(True), torch.tensor(False)
2653        )  # True + False -> x + x
2654        result_false_true = graph.forward(
2655            x, torch.tensor(False), torch.tensor(True)
2656        )  # False + either -> cos
2657        result_false_false = graph.forward(
2658            x, torch.tensor(False), torch.tensor(False)
2659        )  # False + either -> cos
2660
2661        self.assertNotEqual(result_true_true, result_true_false)
2662        self.assertFalse(torch.allclose(result_false_true, result_true_true))
2663
2664        self.assertEqual(result_false_true, result_false_false)
2665
2666        self.assertEqual(result_true_true, (x * x) + x)
2667        self.assertEqual(result_true_false, x + x + x)
2668
2669        self.assertEqual(result_false_true, torch.cos(x))
2670
2671    def test_cond_nested_traced_other_inputs_fake_tensor(self):
2672        def true_nested(y):
2673            return y * y
2674
2675        def false_nested(y):
2676            return y + y
2677
2678        def true_fn(k, pred2):
2679            z = cond(pred2, true_nested, false_nested, [k])
2680            return torch.add(torch.tensor([0.25, 0.25]), z)
2681
2682        def false_fn(k, _):
2683            return k.cos()
2684
2685        def f(k, pred, pred2):
2686            return cond(pred, true_fn, false_fn, [k, pred2])
2687
2688        x = torch.tensor([0.5, 0.5])
2689        graph = make_fx(f, tracing_mode="fake")(
2690            x, torch.tensor(False), torch.tensor(False)
2691        )
2692
2693        a = torch.tensor([1.0, 1.0])
2694        result_true_true = graph.forward(a, torch.tensor(True), torch.tensor(True))
2695        self.assertEqual(result_true_true, (a * a) + torch.tensor([0.25, 0.25]))
2696
2697        b = torch.tensor([2.0, 2.0])
2698        result_true_true = graph.forward(b, torch.tensor(True), torch.tensor(True))
2699        self.assertEqual(result_true_true, (b * b) + torch.tensor([0.25, 0.25]))
2700
2701    def test_cond_nested_traced_multi_fake_tensor(self):
2702        def true_a(y):
2703            return y * y
2704
2705        def false_a(y):
2706            return y + y
2707
2708        def true_b(y, z):
2709            return y + z
2710
2711        def false_b(y, z):
2712            return y * z
2713
2714        def f(x, pred, pred2):
2715            a_out = cond(pred, true_a, false_a, [x])
2716            b_out = cond(pred2, true_b, false_b, [x, x])
2717            return a_out + b_out
2718
2719        x = torch.randn(4)
2720        graph = make_fx(f, tracing_mode="fake")(
2721            x, torch.tensor(False), torch.tensor(False)
2722        )
2723
2724        self.assertExpectedInline(
2725            graph.code.strip(),
2726            """\
2727def forward(self, x_1, pred_1, pred2_1):
2728    true_graph_0 = self.true_graph_0
2729    false_graph_0 = self.false_graph_0
2730    cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]);  pred_1 = true_graph_0 = false_graph_0 = None
2731    getitem = cond[0];  cond = None
2732    true_graph_1 = self.true_graph_1
2733    false_graph_1 = self.false_graph_1
2734    cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]);  pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
2735    getitem_1 = cond_1[0];  cond_1 = None
2736    add = torch.ops.aten.add.Tensor(getitem, getitem_1);  getitem = getitem_1 = None
2737    return add""",  # noqa: B950
2738        )
2739        self.assertExpectedInline(
2740            graph.true_graph_0.code.strip(),
2741            """\
2742def forward(self, arg0_1):
2743    mul = torch.ops.aten.mul.Tensor(arg0_1, arg0_1);  arg0_1 = None
2744    return (mul,)""",
2745        )
2746
2747    def test_raise_error_on_mismatch_type_size_fake_tensor(self):
2748        def true_fn(x):
2749            return x.sin()
2750
2751        def false_fn(x):
2752            return (x, x)
2753
2754        def f(x, y):
2755            return cond(y, true_fn, false_fn, [x])
2756
2757        x = torch.randn(4)
2758        with self.assertRaisesRegex(
2759            torch._dynamo.exc.CondOpArgsMismatchError,
2760            "Expected to return same number of outputs but got:",
2761        ):
2762            make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
2763
2764    def test_raise_error_on_mismatch_tensor_size_fake_tensor(self):
2765        def true_fn(x):
2766            return x.sin()
2767
2768        def false_fn(x):
2769            return torch.zeros([10, 10])
2770
2771        def f(x, y):
2772            return cond(y, true_fn, false_fn, [x])
2773
2774        x = torch.randn(4)
2775        with self.assertRaisesRegex(
2776            torch._dynamo.exc.UncapturedHigherOrderOpError,
2777            "Cond doesn't work unless it is captured completely with torch.compile",
2778        ):
2779            make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
2780
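    # Helper: counts call_function nodes targeting torch.ops.higher_order.map_impl
    # across gm and all of its submodules and asserts the count equals op_count.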
2781    def check_map_count(self, gm, op_count):
2782        i = 0
2783        for m in gm.modules():
2784            for node in m.graph.nodes:
2785                if (
2786                    node.op == "call_function"
2787                    and node.target == torch.ops.higher_order.map_impl
2788                ):
2789                    i += 1
2790        self.assertEqual(i, op_count)
2791
2792    def test_tracing_map_real(self):
2793        def f(x, y):
2794            return x + y
2795
2796        def g(xs, y):
2797            return control_flow.map(f, xs, y)
2798
2799        gm = make_fx(g, tracing_mode="real")(torch.ones(3, 2, 2), torch.ones(2))
2800        x = torch.randn(3, 2, 2)
2801        y = torch.randn(2)
2802        res = gm(x, y)
2803        self.assertEqual(res, g(x, y))
2804        self.check_map_count(gm, 1)
2805
2806    def test_tracing_map_symbolic_simple(self):
2807        def f(x, y):
2808            return x + y
2809
2810        def g(xs, y):
2811            return control_flow.map(f, xs, y)
2812
2813        gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 2, 4), torch.ones(4))
2814        x = torch.randn(3, 2, 2)
2815        y = torch.randn(2)
2816        res = gm(x, y)
2817        self.assertEqual(res, g(x, y))
2818        self.check_map_count(gm, 1)
2819
2820    def test_tracing_map_symbolic_list(self):
2821        def f(x, y):
2822            return [x[0][0] + y, x[1] * y]
2823
2824        def g(xs, y, z):
2825            out = control_flow.map(f, xs, y)
2826            return out[0] + z, out[1] * z
2827
2828        example_x = [[torch.ones(3, 4, 5)], torch.ones(3, 4, 5)]
2829        gm = make_fx(g, tracing_mode="symbolic")(
2830            example_x, torch.ones(5), torch.ones(5)
2831        )
2832        x = [[torch.randn(4, 5, 6)], torch.ones(4, 5, 6)]
2833        y = torch.randn(6)
2834        z = torch.ones(6)
2835        res = gm(x, y, z)
2836        self.assertEqual(res, g(x, y, z))
2837        self.check_map_count(gm, 1)
2838
2839    def test_tracing_map_symbolic_dict(self):
2840        def f(x, y):
2841            return {"d": x["b"]["a"] + y, "e": x["c"] * y}
2842
2843        def g(xs, y, z):
2844            out = control_flow.map(f, xs, y)
2845            return {"f": out["d"] + z, "g": out["e"] * z}
2846
2847        example_x = {"b": {"a": torch.ones(3, 4, 5)}, "c": torch.ones(3, 4, 5)}
2848        gm = make_fx(g, tracing_mode="symbolic")(
2849            example_x, torch.ones(5), torch.ones(5)
2850        )
2851        x = {"b": {"a": torch.randn(4, 5, 6)}, "c": torch.ones(4, 5, 6)}
2852        y = torch.randn(6)
2853        z = torch.ones(6)
2854        res = gm(x, y, z)
2855        self.assertEqual(res, g(x, y, z))
2856        self.check_map_count(gm, 1)
2857
2858    def test_tracing_map_autograd_symbolic_simple(self):
2859        def f(x, y):
2860            return x + y
2861
2862        def g(xs, y):
2863            out = control_flow.map(f, xs, y)
2864            return torch.autograd.grad(out, (xs, y), torch.ones_like(out))
2865
2866        gm = make_fx(g, tracing_mode="symbolic")(
2867            torch.ones(3, 4, 5, requires_grad=True), torch.ones(5, requires_grad=True)
2868        )
2869        x = torch.randn(4, 5, 6, requires_grad=True)
2870        y = torch.randn(6, requires_grad=True)
2871        res = gm(x, y)
2872        self.assertEqual(res, g(x, y))
2873        self.check_map_count(gm, 2)
2874
2875    def test_tracing_map_autograd_symbolic_list(self):
2876        import torch.utils._pytree as pytree
2877
2878        def f(x, y):
2879            return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]
2880
2881        def g(xs, y):
2882            out = control_flow.map(f, xs, y)
2883            flat_out = pytree.tree_leaves(out)
2884            flat_inp = pytree.tree_leaves((xs, y))
2885            requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
2886            return torch.autograd.grad(
2887                flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]
2888            )
2889
2890        gm = make_fx(g, tracing_mode="symbolic")(
2891            [torch.ones(3, 4, 5), torch.ones(3, 4, 5, requires_grad=True)],
2892            torch.ones(5, requires_grad=True),
2893        )
2894        x = [torch.randn(4, 5, 6), torch.ones(4, 5, 6, requires_grad=True)]
2895        y = torch.randn(6, requires_grad=True)
2896        res = gm(x, y)
2897        self.assertEqual(res, g(x, y))
2898        self.check_map_count(gm, 2)
2899
2900    def test_tracing_map_autograd_symbolic_dict(self):
2901        def f(x, y):
2902            return [x["a"] + y, x["b"] * y]
2903
2904        def g(xs, y):
2905            out = control_flow.map(f, xs, y)
2906            flat_out = pytree.tree_leaves(out)
2907            flat_inp = pytree.tree_leaves((xs, y))
2908            requires_grad_inp = [inp for inp in flat_inp if inp.requires_grad]
2909            return torch.autograd.grad(
2910                flat_out, requires_grad_inp, [torch.ones_like(out) for out in flat_out]
2911            )
2912
2913        traced_x = {
2914            "a": torch.ones(3, 4, 5, requires_grad=True),
2915            "b": torch.ones(3, 4, 5, requires_grad=True),
2916        }
2917        gm = make_fx(g, tracing_mode="symbolic")(
2918            traced_x, torch.ones(5, requires_grad=True)
2919        )
2920        x = {
2921            "a": torch.randn(4, 5, 6, requires_grad=True),
2922            "b": torch.ones(4, 5, 6, requires_grad=True),
2923        }
2924        y = torch.randn(6, requires_grad=True)
2925        res = gm(x, y)
2926        self.assertEqual(res, g(x, y))
2927        self.check_map_count(gm, 2)
2928
2929    def test_tracing_map_autograd_aot_functionalized(self):
2930        def inner(x, y):
2931            z = x - 1
2932            z.add_(1)
2933            return z * y
2934
2935        def f(xs, y):
2936            res = control_flow.map(inner, xs, y)
2937            grads = torch.autograd.grad(res, (xs, y), torch.ones_like(res))
2938            return grads
2939
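        # Wrap f so it runs under the legacy functionalization mode and its
        # outputs are converted back from functional tensors before make_fx
        # records them.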
2940        def f_wrapper(func):
2941            @functools.wraps(func)
2942            def wrapper(*args, **kwargs):
2943                torch._enable_functionalization(reapply_views=False)
2944                try:
2945                    return pytree.tree_map(from_fun_old, func(*args, **kwargs))
2946                finally:
2947                    torch._disable_functionalization()
2948
2949            return wrapper
2950
2951        example_inputs = (
2952            torch.ones(3, 2, 4, requires_grad=True),
2953            torch.ones(2, 4, requires_grad=True),
2954        )
2955        gm = make_fx(f, tracing_mode="symbolic")(*example_inputs)
2956        fgm = make_fx(f_wrapper(f), tracing_mode="symbolic")(*example_inputs)
2957        xs = torch.ones(3, 4, 5, requires_grad=True)
2958        y = torch.ones(4, 5, requires_grad=True)
2959
2960        self.assertEqual(gm(xs, y), f(xs, y))
2961
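        # Recursively count ops with a mutable schema, descending into the
        # body graphs of any nested map_impl calls.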
2962        def count_mutable(gm):
2963            c = 0
2964            for node in gm.graph.nodes:
2965                if node.op == "call_function":
2966                    if node.target == torch.ops.higher_order.map_impl:
2967                        c += count_mutable(getattr(gm, str(node.args[0])))
2968                    elif schema := getattr(node.target, "_schema", None):
2969                        c += int(schema.is_mutable)
2970            return c
2971
2972        self.assertEqual(count_mutable(fgm), 0)
2973        # One for forward, one for recomputation logic in backward
2974        self.assertEqual(count_mutable(gm), 2)
2975
2976    def test_map_functionalized(self):
2977        def map_fn(x, y):
2978            z = x + y
2979            z.add_(4)
2980            return z
2981
2982        def f(xs, y):
2983            return control_flow.map(map_fn, xs, y)
2984
2985        example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
2986        functional_f = torch.func.functionalize(f)
2987        self.assertEqual(functional_f(*example_inputs), f(*example_inputs))
2988
2989        gm = make_fx(torch.func.functionalize(f))(*example_inputs)
2990        self.assertEqual(gm(*example_inputs), f(*example_inputs))
2991
2992        gm = make_fx(torch.func.functionalize(f), tracing_mode="symbolic")(
2993            *example_inputs
2994        )
2995        self.assertEqual(gm(*example_inputs), f(*example_inputs))
2996
2997        for node in gm.body_graph_0.graph.nodes:
2998            if node.op == "call_function":
2999                self.assertTrue(not node.target._schema.is_mutable)
3000        self.check_map_count(gm, 1)
3001
3002    def test_map_functionalized_aot_func(self):
3003        def map_fn(x, y):
3004            z = x + y
3005            z.add_(4)
3006            return z
3007
3008        def f(xs, y):
3009            return control_flow.map(map_fn, xs, y)
3010
3011        def f_wrapper(func):
3012            @functools.wraps(func)
3013            def wrapper(*args, **kwargs):
3014                torch._enable_functionalization(reapply_views=False)
3015                try:
3016                    return pytree.tree_map(from_fun_old, func(*args, **kwargs))
3017                finally:
3018                    torch._disable_functionalization()
3019
3020            return wrapper
3021
3022        example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
3023
3024        gm = make_fx(f_wrapper(f))(*example_inputs)
3025
3026        for node in gm.body_graph_0.graph.nodes:
3027            if node.op == "call_function":
3028                self.assertTrue(not node.target._schema.is_mutable)
3029
3030        self.assertEqual(gm(*example_inputs), f(*example_inputs))
3031
3032    # https://github.com/pytorch/pytorch/issues/126988
3033    @xfailIfTorchDynamo
3034    def test_map_functionalized_arg_mutation(self):
3035        def map_fn(x, y):
3036            y.add_(4)
3037            return x + y
3038
3039        def f(xs, y):
3040            return control_flow.map(map_fn, xs, y)
3041
3042        example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
3043        functional_f = torch.func.functionalize(f)
3044        with self.assertRaisesRegex(
3045            UnsupportedAliasMutationException, "torch.map is mutating the input!"
3046        ):
3047            functional_f(*example_inputs)
3048
3049    # https://github.com/pytorch/pytorch/issues/126988
3050    @xfailIfTorchDynamo
3051    def test_map_functionalized_elem_mutation(self):
3052        def map_fn(x, y):
3053            x.add_(4)
3054            return x + y
3055
3056        def f(xs, y):
3057            return control_flow.map(map_fn, xs, y)
3058
3059        example_inputs = (torch.ones(3, 2, 4), torch.ones(4))
3060        functional_f = torch.func.functionalize(f)
3061        with self.assertRaisesRegex(
3062            UnsupportedAliasMutationException, "torch.map is mutating the input!"
3063        ):
3064            functional_f(*example_inputs)
3065
3066    def test_cond_autograd_backward(self):
3067        def true_fn(x):
3068            return x.cos()
3069
3070        def false_fn(x):
3071            return x.sin()
3072
3073        def f(x, y):
3074            return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [y])
3075
3076        example_inputs = (
3077            torch.ones(3, 2, 4, requires_grad=True),
3078            torch.ones(4, requires_grad=True),
3079        )
3080        f(*example_inputs).sum().backward()
3081
3082        # Ensure no error is thrown when not running backward
3083        res = f(*example_inputs)
3084
3085        # Ensure the compiled version also runs without error and matches eager
3086        res_compiled = torch.compile(f)(*example_inputs)
3087        self.assertEqual(res, res_compiled)
3088
3089    # https://github.com/pytorch/pytorch/issues/126988
3090    @xfailIfTorchDynamo
3091    def test_map_functionalized_elem_alias(self):
3092        def map_fn(x):
3093            x.view(x.shape)
3094            return x
3095
3096        def f(xs):
3097            return control_flow.map(map_fn, xs)
3098
3099        example_inputs = (torch.ones(3, 2, 4),)
3100        functional_f = torch.func.functionalize(f)
3101        with self.assertRaisesRegex(
3102            UnsupportedAliasMutationException, "torch.map is aliasing the input!"
3103        ):
3104            functional_f(*example_inputs)
3105
3106    def test_nested_map_cond_real(self):
3107        def true_fn(x, y):
3108            return x * y
3109
3110        def false_fn(x, y):
3111            return x + y
3112
3113        def f(x, pred, y):
3114            return cond(pred, true_fn, false_fn, [x, y])
3115
3116        def g(pred, xs, y):
3117            return control_flow.map(f, xs, pred, y)
3118
3119        gm = make_fx(g, tracing_mode="real")(
3120            torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
3121        )
3122        pred = torch.tensor(False)
3123        x = torch.randn(3, 2, 4)
3124        y = torch.randn(4)
3125        res = gm(pred, x, y)
3126        self.assertEqual(res, g(pred, x, y))
3127        self.check_map_count(gm, 1)
3128
3129    def test_nested_map_cond_symbolic(self):
3130        def true_fn(x, y):
3131            return x * y
3132
3133        def false_fn(x, y):
3134            return x + y
3135
3136        def f(x, pred, y):
3137            return cond(pred, true_fn, false_fn, [x, y])
3138
3139        def g(pred, xs, y):
3140            return control_flow.map(f, xs, pred, y)
3141
3142        gm = make_fx(g, tracing_mode="symbolic")(
3143            torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
3144        )
3145        pred = torch.tensor(False)
3146        x = torch.randn(3, 2, 2)
3147        y = torch.randn(2)
3148        res = gm(pred, x, y)
3149        self.assertEqual(res, g(pred, x, y))
3150        self.check_map_count(gm, 1)
3151
3152    def test_nested_cond_map_cond_symbolic(self):
3153        def true_fn(x, y):
3154            return x * y
3155
3156        def false_fn(x, y):
3157            return x + y
3158
3159        def f(x, pred, y):
3160            return cond(pred, true_fn, false_fn, [x, y])
3161
3162        def g(pred, xs, y):
3163            return control_flow.map(f, xs, pred, y)
3164
3165        def main_true_fn(pred, xs, y):
3166            return g(pred, xs, y) * 2
3167
3168        def main_false_fn(pred, xs, y):
3169            return g(pred, xs, y) + 1
3170
3171        def main(p, pred, xs, y):
3172            return cond(p, main_true_fn, main_false_fn, [pred, xs, y])
3173
3174        gm = make_fx(main, tracing_mode="symbolic")(
3175            torch.tensor(True), torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4)
3176        )
3177        p = torch.tensor(False)
3178        pred = torch.tensor(False)
3179        xs = torch.randn(3, 2, 2)
3180        y = torch.randn(2)
3181        res = gm(p, pred, xs, y)
3182        self.assertEqual(res, main(p, pred, xs, y))
3183        self.check_map_count(gm, 2)
3184
3185    def test_cond_with_sym_pred(self):
3186        def true_fn(x):
3187            return x + x
3188
3189        def false_fn(x):
3190            return x * x
3191
3192        def foo(x):
3193            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
3194
3195        gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 2, 1))
3196        # The symbols in make_fx's shape_env should not be specialized.
3197        self.assertEqual(len(gm.shape_env.guards), 0)
3198
3199        self.assertExpectedInline(
3200            gm.code.strip(),
3201            """\
3202def forward(self, x_1):
3203    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
3204    eq = sym_size_int == 4;  sym_size_int = None
3205    true_graph_0 = self.true_graph_0
3206    false_graph_0 = self.false_graph_0
3207    cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]);  eq = true_graph_0 = false_graph_0 = x_1 = None
3208    getitem = cond[0];  cond = None
3209    return getitem""",  # noqa: B950
3210        )
3211
3212        # We expect the traced graph module to work even if input size changes.
3213        x = torch.ones(4, 3, 2)
3214        self.assertEqual(gm(x), true_fn(x))
3215        self.assertEqual(foo(x), true_fn(x))
3216
3217    def test_cond_with_unbacked_sym_pred(self):
3218        def foo(x):
3219            def true_fn(x):
3220                return x + x
3221
3222            def false_fn(x):
3223                return x * x
3224
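            # nonzero produces a data-dependent number of rows, so the
            # predicate below is an unbacked SymBool rather than a constant.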
3225            az = x.nonzero()
3226            return cond(az.shape[0] > 3, true_fn, false_fn, (x,))
3227
3228        gm = make_fx(foo, tracing_mode="symbolic")(torch.randn(7))
3229        self.assertExpectedInline(
3230            gm.code.strip(),
3231            """\
3232def forward(self, x_1):
3233    nonzero = torch.ops.aten.nonzero.default(x_1)
3234    sym_size_int = torch.ops.aten.sym_size.int(nonzero, 0);  nonzero = None
3235    gt = sym_size_int > 3;  sym_size_int = None
3236    true_graph_0 = self.true_graph_0
3237    false_graph_0 = self.false_graph_0
3238    cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1]);  gt = true_graph_0 = false_graph_0 = x_1 = None
3239    getitem = cond[0];  cond = None
3240    return getitem""",
3241        )
3242
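    # Helper: traces f with make_fx, checks both eager and traced results match
    # exp_res, and verifies that each child graph (i.e. each cond branch) ends
    # up with exp_arg_num placeholders once closed-over tensors are lifted.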
3243    def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num):
3244        assert isinstance(args, (tuple, list))
3245        self.assertEqual(f(*args), exp_res)
3246        gm = make_fx(f)(*args)
3247        self.assertEqual(gm(*args), exp_res)
3248
3249        def cnt_placeholder(gm):
3250            return len([node for node in gm.graph.nodes if node.op == "placeholder"])
3251
3252        placeholder_cnts = [cnt_placeholder(mod) for mod in gm.children()]
3253        self.assertTrue(all(cnt == exp_arg_num for cnt in placeholder_cnts))
3254
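    # Helper: runs the check above, then mutates the closed-over tensors in
    # place and re-checks, ensuring the traced branches see the updated closure
    # values and still expose them as explicit placeholder inputs.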
3255    def _check_closure_correctly_lifted_with_mutation(
3256        self, f, closures_to_be_mutated, *, args, exp_arg_num
3257    ):
3258        exp_res = f(*args)
3259        self._check_closure_correctly_lifted(
3260            f, args=args, exp_res=exp_res, exp_arg_num=exp_arg_num
3261        )
3262
3263        for closure in closures_to_be_mutated:
3264            closure.add_(-1)
3265        new_exp_res = f(*args)
3266
3267        self._check_closure_correctly_lifted(
3268            f, args=args, exp_res=new_exp_res, exp_arg_num=exp_arg_num
3269        )
3270
3271    def test_cond_with_tensor_closure(self):
3272        a = torch.ones(2, 3)
3273        b = torch.ones(2, 3) + 1
3274
3275        def true_fn(x):
3276            return x + a
3277
3278        def false_fn(x):
3279            return x + b
3280
3281        def foo(x):
3282            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
3283
3284        # both branches are expected to take [x, a, b] as input
3285        inp = torch.randn(2, 3)
3286        self._check_closure_correctly_lifted_with_mutation(
3287            foo, (a, b), args=(inp,), exp_arg_num=3
3288        )
3289
3290    def test_cond_with_tensor_closure_graph_module(self):
3291        a = torch.ones(2, 3)
3292        b = torch.ones(2, 3) + 1
3293
3294        def true_fn(x):
3295            return x + a
3296
3297        def false_fn(x):
3298            return x + b
3299
3300        def foo(x):
3301            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
3302
3303        # both branches are expected to take [x, a, b] as input
3304        inp = torch.randn(2, 3)
3305
3306        gm = make_fx(foo, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp)
3307
3308        self.assertExpectedInline(
3309            gm.code.strip(),
3310            """\
3311def forward(self, x_1):
3312    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
3313    eq = sym_size_int == 4;  sym_size_int = None
3314    true_graph_0 = self.true_graph_0
3315    false_graph_0 = self.false_graph_0
3316    _tensor_constant0 = self._tensor_constant0
3317    _tensor_constant1 = self._tensor_constant1
3318    cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, _tensor_constant1]);  eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = _tensor_constant1 = None
3319    getitem = cond[0];  cond = None
3320    return getitem""",  # noqa: B950
3321        )
3322        self.assertExpectedInline(
3323            gm.true_graph_0.code.strip(),
3324            """\
3325def forward(self, arg0_1, arg1_1, arg2_1):
3326    add = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
3327    return (add,)""",
3328        )
3329
3330    def test_cond_with_module_param_closure(self):
3331        class Mod(torch.nn.Module):
3332            def __init__(self) -> None:
3333                super().__init__()
3334                self.register_parameter(
3335                    "param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False)
3336                )
3337                self.buffer = torch.nn.Buffer(torch.ones(2, 3) + 1)
3338
3339        my_mode = Mod()
3340
3341        def true_fn(x):
3342            return x + my_mode.param
3343
3344        def false_fn(x):
3345            return x + my_mode.buffer
3346
3347        def foo(x):
3348            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
3349
3350        inp = torch.ones(2, 3)
3351        # both branches are expected to take (x, param, buffer)
3352        self._check_closure_correctly_lifted_with_mutation(
3353            foo, (my_mode.param, my_mode.buffer), args=(inp,), exp_arg_num=3
3354        )
3355
3356    def test_cond_with_module_python_scalar_closure(self):
3357        def foo(x):
3358            a = torch.ones(1, 1)
3359            b = 1
3360
3361            def true_fn(x):
3362                return x + a
3363
3364            def false_fn(x):
3365                return x + b
3366
3367            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
3368
3369        inp = torch.ones(2, 3)
3370        res = inp + 1
3371        # the Python scalar b is not lifted as an input, so both branches take (x, a)
3372        self._check_closure_correctly_lifted(
3373            foo, args=(inp,), exp_res=res, exp_arg_num=2
3374        )
3375
3376    def test_cond_nested_with_closure(self):
3377        a = torch.ones(1, 1)
3378        b = torch.ones(1, 1) + 1
3379
3380        def inner_true_fn(x):
3381            return x + a
3382
3383        def inner_false_fn(x):
3384            return x + b
3385
3386        def foo(x):
3387            def true_fn(x):
3388                return cond(x.shape[0] == 2, inner_true_fn, inner_false_fn, [x])
3389
3390            def false_fn(x):
3391                return cond(x.shape[0] > 4, inner_true_fn, inner_false_fn, [x])
3392
3393            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
3394
3395        inp = torch.ones(2, 3)
3396        # The top-level cond takes 3 arguments (x, a, b). Dynamo should
3397        # realize that the nonlocal variables are the same for the true and
3398        # false branches, so it should de-dupe them.
3399        # Each second-level cond also takes (x, a, b).
3400        self._check_closure_correctly_lifted_with_mutation(
3401            foo, (a, b), args=(inp,), exp_arg_num=3
3402        )
3403
3404    def test_cond_nested_with_closure_graph_module(self):
3405        a = torch.ones(1, 1)
3406        b = torch.ones(1, 1) + 1
3407
3408        def inner_true_fn(x):
3409            return x + a
3410
3411        def inner_false_fn(x):
3412            return x + b
3413
3414        def foo(x):
3415            def true_fn(x):
3416                return cond(x.shape[0] == 2, inner_true_fn, inner_false_fn, [x])
3417
3418            def false_fn(x):
3419                return cond(x.shape[0] > 4, inner_true_fn, inner_false_fn, [x])
3420
3421            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
3422
3423    def test_map_unfunc_boolean_tensor_for_nested_map_cond(self):
3424        def map_fn(pred, x):
3425            def fn(x, pred):
3426                return control_flow.cond(pred, lambda x: x * 2, lambda x: x / 2, (x,))
3427
3428            return control_flow.map(fn, x, pred)
3429
3430        def f_wrapper(func):
3431            @functools.wraps(func)
3432            def wrapper(*args, **kwargs):
3433                torch._enable_functionalization(reapply_views=False)
3434                try:
3435                    func_args = pytree.tree_map(
3436                        lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
3437                        args,
3438                    )
3439                    func_kwargs = pytree.tree_map(
3440                        lambda x: to_fun_old(x) if isinstance(x, torch.Tensor) else x,
3441                        kwargs,
3442                    )
3443                    return pytree.tree_map(
3444                        from_fun_old, func(*func_args, **func_kwargs)
3445                    )
3446                finally:
3447                    torch._disable_functionalization()
3448
3449            return wrapper
3450
3451        gm = make_fx(f_wrapper(map_fn))(
3452            torch.tensor(True), torch.ones([2, 3], requires_grad=False)
3453        )
3454        self.assertExpectedInline(
3455            gm.code.strip(),
3456            """\
3457def forward(self, pred_1, x_1):
3458    body_graph_0 = self.body_graph_0
3459    map_impl = torch.ops.higher_order.map_impl(body_graph_0, [x_1], [pred_1]);  body_graph_0 = x_1 = pred_1 = None
3460    getitem = map_impl[0];  map_impl = None
3461    return getitem""",
3462        )
3463        self.assertExpectedInline(
3464            gm.body_graph_0.code.strip(),
3465            """\
3466def forward(self, arg0_1, arg1_1):
3467    true_graph_0 = self.true_graph_0
3468    false_graph_0 = self.false_graph_0
3469    cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]);  arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
3470    getitem = cond[0];  cond = None
3471    return [getitem]""",  # noqa: B950
3472        )
3473
3474    def test_cond_make_fx_preserve_stack_trace_for_nodes_in_subgraph(self):
3475        def true_fn(x):
3476            return x + x.cos()
3477
3478        def false_fn(x):
3479            return x * x.sin()
3480
3481        def foo(x):
3482            return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
3483
3484        inp = torch.randn([4, 3])
3485        gm, _ = torch._dynamo.export(foo)(inp)
3486
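        # Re-run the exported graph through an fx.Interpreter under
        # preserve_node_meta() so that stack_trace and source_fn_stack metadata
        # survive into the re-traced graph.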
3487        def run_with_interpreter(*args):
3488            with torch.fx.traceback.preserve_node_meta():
3489                return torch.fx.Interpreter(gm).run(*args)
3490
3491        new_gm = make_fx(run_with_interpreter)(inp)
3492
3493        checked_ops = {"add", "mul", "sin", "cos"}
3494        checked_meta = ["source_fn_stack", "stack_trace"]
3495        all_source_fns = collect_meta_for_filtered_nodes(gm, checked_ops, checked_meta)
3496        new_source_fns = collect_meta_for_filtered_nodes(
3497            new_gm, checked_ops, checked_meta
3498        )
3499        self.assertEqual(all_source_fns, new_source_fns)
3500
3501    @unittest.skipIf(
3502        TEST_WITH_TORCHDYNAMO,
3503        "triggers cache limit for foo and changes unique_graphs count.",
3504    )
3505    def test_cond_no_dynamo_cache_limit(self):
3506        torch._dynamo.reset()
3507        counters = torch._dynamo.utils.counters
3508        counters.clear()
3509
3510        def foo(x, true_fn, false_fn):
3511            return cond(x.sum() < 0, true_fn, false_fn, (x,))
3512
3513        inp = torch.ones(3, 4)
3514        exp_out = inp.sin()
3515        iter_n = torch._dynamo.config.cache_size_limit + 1
3516
3517        # Need this because Dynamo guards on the lambda's code object, not on the lambda object itself.
3518        def make_dummy_fn(op):
3519            exec(f"temp = lambda x: x.{op}()")
3520            return locals()["temp"]
3521
3522        for _ in range(iter_n):
3523            # each dummy fn has a different code object and thus fails the guard
3524            self.assertEqual(
3525                foo(inp, make_dummy_fn("cos"), make_dummy_fn("sin")), exp_out
3526            )
3527
3528        # each iteration captures a cond and a getitem from the tuple output
3529        self.assertEqual(counters["stats"]["calls_captured"], iter_n * 2)
3530        self.assertEqual(counters["stats"]["unique_graphs"], iter_n)
3531
3532    def test_cond_with_consecutive_make_fx_symbolic(self):
3533        def true_fn(x):
3534            return x - x.cos()
3535
3536        def false_fn(x):
3537            return x + x.sin()
3538
3539        def foo(x):
3540            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
3541
3542        inps = (torch.ones(3, 4), torch.ones(3, 5), torch.ones(5, 4), torch.ones(5, 3))
3543        for inp in inps:
3544            gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 4))
3545            self.assertExpectedInline(
3546                gm.code.strip(),
3547                """\
3548def forward(self, x_1):
3549    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
3550    eq = sym_size_int == 4;  sym_size_int = None
3551    true_graph_0 = self.true_graph_0
3552    false_graph_0 = self.false_graph_0
3553    cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1]);  eq = true_graph_0 = false_graph_0 = x_1 = None
3554    getitem = cond[0];  cond = None
3555    return getitem""",  # noqa: B950
3556            )
3557
3558            self.assertExpectedInline(
3559                gm.true_graph_0.code.strip(),
3560                """\
3561def forward(self, arg0_1):
3562    cos = torch.ops.aten.cos.default(arg0_1)
3563    sub = torch.ops.aten.sub.Tensor(arg0_1, cos);  arg0_1 = cos = None
3564    return (sub,)""",
3565            )
3566
3567            self.assertExpectedInline(
3568                gm.false_graph_0.code.strip(),
3569                """\
3570def forward(self, arg0_1):
3571    sin = torch.ops.aten.sin.default(arg0_1)
3572    add = torch.ops.aten.add.Tensor(arg0_1, sin);  arg0_1 = sin = None
3573    return (add,)""",
3574            )
3575
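    # Helper: recursively builds a (possibly nested) cond-based test function.
    # At nesting level 0 the branches apply inner_most_fn to the operands (and
    # to the closure tensors, if any); at deeper levels each branch also calls
    # the function built for the level below. Returns (operands, fn).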
3576    def _create_test_fns_for_cond(
3577        self, pred, inner_most_fn, operands, closure_list, nested_level
3578    ):
3579        if nested_level == 0:
3580            if len(closure_list) > 0:
3581
3582                def true_fn(*operands):
3583                    return inner_most_fn(*operands) + inner_most_fn(*closure_list)
3584
3585                def false_fn(*operands):
3586                    return inner_most_fn(*operands) - inner_most_fn(*closure_list)
3587
3588            else:
3589
3590                def true_fn(*operands):
3591                    return inner_most_fn(*operands)
3592
3593                def false_fn(*operands):
3594                    return inner_most_fn(*operands)
3595
3596            def fn(*operands):
3597                if len(operands) == 0 and len(closure_list) == 0:
3598                    return torch.zeros(1)
3599                return cond(pred, true_fn, false_fn, operands)
3600
3601            return operands, fn
3602        else:
3603            args, inner_fn = self._create_test_fns_for_cond(
3604                pred <= 0, inner_most_fn, operands, closure_list, nested_level - 1
3605            )
3606
3607            def true_fn(*operands):
3608                return inner_most_fn(*operands) + inner_fn(*args)
3609
3610            def false_fn(*operands):
3611                return inner_most_fn(*operands) - inner_fn(*args)
3612
3613            def fn(*operands):
3614                if len(operands) == 0 and len(closure_list) == 0:
3615                    return torch.ones(1)
3616                return cond(pred, true_fn, false_fn, operands)
3617
3618            return operands, fn
3619
3620    def _init_predicate(self, pred_type):
3621        if pred_type == "bool":
3622            return True
3623        elif pred_type == "intTensor":
3624            return torch.tensor(1)
3625        elif pred_type == "floatTensor":
3626            return torch.tensor(1.0)
3627        elif pred_type == "boolTensor":
3628            return torch.tensor(False)
3629        else:
3630            raise NotImplementedError
3631
3632    def _init_fn(self, inner_fn_type):
3633        if inner_fn_type == "function":
3634            return reduce_func
3635        elif inner_fn_type == "module":
3636            return ReduceMod()
3637        elif inner_fn_type == "object":
3638            return ReduceObj()
3639        else:
3640            raise NotImplementedError
3641
3642    @parametrize("predType", ["bool", "intTensor", "floatTensor", "boolTensor"])
3643    @parametrize("innerFnType", ["function", "module", "object"])
3644    @parametrize("nOperands", [0, 1])
3645    @parametrize("nClosure", [0, 1])
3646    @parametrize("nesting", [0, 2])
3647    def test_cond_tracing_with_valid_inputs(
3648        self, predType, innerFnType, nOperands, nClosure, nesting
3649    ):
3650        pred = self._init_predicate(predType)
3651        inner_fn = self._init_fn(innerFnType)
3652        operands = [torch.ones(2, 3) + i for i in range(nOperands)]
3653        closure = [torch.ones(2, 3) - i for i in range(nClosure)]
3654        args, fn = self._create_test_fns_for_cond(
3655            pred, inner_fn, operands, closure, nesting
3656        )
3657        eager_res = fn(*args)
3658        for tracing_mode in ["symbolic", "fake", "real"]:
3659            # set _allow_non_fake_inputs = True to allow fake prop through closures
3660            with self.subTest(tracing_mode=tracing_mode):
3661                gm = make_fx(
3662                    fn, tracing_mode=tracing_mode, _allow_non_fake_inputs=True
3663                )(*args)
3664                self.assertEqual(gm(*args), eager_res)
3665
3666    @parametrize("predType", ["boolTensor"])
3667    @parametrize("innerFnType", ["function", "module", "object"])
3668    @parametrize("nOperands", [1, 2])
3669    @parametrize("nClosure", [0, 1])
3670    @parametrize("nesting", [0])
3671    def test_cond_vmap(self, predType, innerFnType, nOperands, nClosure, nesting):
3672        pred = self._init_predicate(predType)
3673        inner_fn = self._init_fn(innerFnType)
3674        operands = [torch.ones(2, 3) + i for i in range(nOperands)]
3675        closure = [torch.ones(2, 3) - i for i in range(nClosure)]
3676        args, fn = self._create_test_fns_for_cond(
3677            pred, inner_fn, operands, closure, nesting
3678        )
3679        eager_res = fn(*args)
3680        out = torch.vmap(fn)(*args)
3681        if nClosure == 0:
3682            self.assertEqual(eager_res, out)
3683        else:
3684            self.assertEqual(eager_res, out[0])
3685            self.assertEqual(eager_res, out[1])
3686
3687    def test_cond_vmap_simple(self):
3688        def fn(x):
3689            return torch.cond(
3690                pred=torch.tensor([True]),
3691                true_fn=lambda x: x + 100,
3692                false_fn=lambda x: x,
3693                operands=(x,),
3694            )
3695
3696        a = torch.arange(15).reshape((3, 5))
3697        res = torch.vmap(fn, in_dims=(0,))(a)
3698        self.assertEqual(res.shape, (3, 5))
3699        self.assertEqual(res, a + 100)
3700
3701    def test_cond_vmap_multiple_inputs(self):
3702        def fn(x, y):
3703            return torch.cond(
3704                pred=x.sum() < y.sum(),
3705                true_fn=lambda x, y: x + 100,
3706                false_fn=lambda x, y: y,
3707                operands=(x, y),
3708            )
3709
3710        a = torch.arange(15).reshape(3, 5)
3711        b = torch.ones_like(a) + 3
3712        res = torch.vmap(fn, in_dims=(0, 0))(a, b)
3713        expected = torch.tensor(
3714            [[100, 101, 102, 103, 104], [4, 4, 4, 4, 4], [4, 4, 4, 4, 4]]
3715        )
3716        self.assertEqual(res.shape, (3, 5))
3717        self.assertEqual(expected, res)
3718
3719    def test_cond_vmap_single_input_with_closure(self):
3720        a = torch.ones((3, 5)) + 3
3721        c = torch.arange(5)
3722
3723        def fn(x):
3724            return torch.cond(
3725                pred=torch.tensor([True]),
3726                true_fn=lambda x: x + c,
3727                false_fn=lambda x: x - c,
3728                operands=(x,),
3729            )
3730
3731        res = torch.vmap(fn, in_dims=(0,))(
3732            a,
3733        )
3734        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
3735            res = torch.vmap(fn, in_dims=(0,))(
3736                a,
3737            )
3738        self.assertEqual(a + c, res)
3739
3740    def test_cond_vmap_multiple_args_with_closure(self):
3741        a = torch.ones((3, 5), dtype=torch.int64) + 3
3742        b = torch.arange(15).reshape(3, 5)
3743        c = torch.arange(5)
3744
3745        def fn(x, y):
3746            return torch.cond(
3747                pred=torch.tensor([False]),
3748                true_fn=lambda x, y: x + c,
3749                false_fn=lambda x, y: y - c,
3750                operands=(x, y),
3751            )
3752
3753        res = torch.vmap(fn)(a, b)
3754        self.assertEqual(b - c, res)
3755
3756    @parametrize("nClosure", [0, 1])
3757    def test_cond_vmap_multiple_outputs(self, nClosure):
3758        if nClosure:
3759            c = torch.ones(5, dtype=torch.int64) + 5
3760
3761            def fn(x):
3762                return torch.cond(
3763                    pred=torch.tensor([True]),
3764                    true_fn=lambda x: (x + c, x - c),
3765                    false_fn=lambda x: (x, x),
3766                    operands=(x,),
3767                )
3768
3769        else:
3770
3771            def fn(x):
3772                return torch.cond(
3773                    pred=torch.tensor([True]),
3774                    true_fn=lambda x: (x + 1, x - 1),
3775                    false_fn=lambda x: (x, x),
3776                    operands=(x,),
3777                )
3778
3779        a = torch.arange(15).reshape(3, 5)
3780        res = torch.vmap(fn)(
3781            a,
3782        )
3783        self.assertEqual(len(res), 2)
3784        if nClosure:
3785            self.assertEqual(res, (a + c, a - c))
3786        else:
3787            self.assertEqual(res, (a + 1, a - 1))
3788
3789    def test_vmap_vmap(self):
3790        def fn(x):
3791            return torch.cond(
3792                pred=torch.tensor([True]),
3793                true_fn=lambda x: x + 1,
3794                false_fn=lambda x: x - 1,
3795                operands=(x,),
3796            )
3797
3798        def wrapper(x):
3799            return torch.vmap(fn)(x)
3800
3801        a = torch.ones((3, 4, 5))
3802        res = torch.vmap(wrapper)(a)
3803        self.assertEqual(res, a + 1)
3804
3805    def test_cond_trace_set__and_mutate_input(self):
3806        def f(a, tmp):
3807            a_view = a.view(-1)
3808            with torch.no_grad():
3809                a.set_(tmp)
3810                a_view.mul_(2)
3811            return a + tmp
3812
3813        inp = torch.ones(3, 3, requires_grad=True)
3814        tmp = torch.ones(3, 3, requires_grad=True)
3815        # graph break: torch._dynamo.exc.Unsupported: call_function DelayGraphBreakVariable() [TensorVariable()] {}
3816        # due to set_
3817        with self.assertRaisesRegex(
3818            torch._dynamo.exc.UncapturedHigherOrderOpError,
3819            "Cond doesn't work unless it is captured completely with torch.compile",
3820        ):
3821            torch.cond(inp.sum() > 0, f, f, (inp, tmp))
3822
3823    def test_cond_trace_set__and_mutate_intermediate(self):
3824        def f(a, tmp):
3825            a = a.clone()
3826            a_view = a.view(-1)
3827            tmp = tmp.clone()
3828            with torch.no_grad():
3829                a.set_(tmp)
3830                a_view.mul_(2)
3831            return a + tmp
3832
3833        inp = torch.ones(3, 3, requires_grad=True)
3834        tmp = torch.ones(3, 3, requires_grad=True)
3835
3836        class Mod(torch.nn.Module):
3837            def forward(self, inp: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor:
3838                return torch.cond(inp.sum() > 0, f, f, (inp, tmp))
3839
3840        with self.assertRaisesRegex(
3841            RuntimeError, "cannot mutate tensors with frozen storage"
3842        ):
3843            out = torch.compile(Mod(), backend="aot_eager")(inp, tmp)
3844
3845        with self.assertRaisesRegex(
3846            RuntimeError, "cannot mutate tensors with frozen storage"
3847        ):
3848            out = torch.compile(Mod(), backend="inductor")(inp, tmp)
3849
3850        from torch._dynamo.testing import EagerAndRecordGraphs
3851
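        # EagerAndRecordGraphs runs eagerly while recording the captured FX
        # graphs, so the traced true branch can be inspected below.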
3852        backend = EagerAndRecordGraphs()
3853        out = torch.compile(Mod(), backend=backend)(inp, tmp)
3854        self.assertExpectedInline(
3855            backend.graphs[0].cond_true_0.code.strip("\n"),
3856            """\
3857def forward(self, l_inp_, l_tmp_):
3858    l_inp__1 = l_inp_
3859    l_tmp__1 = l_tmp_
3860    a = l_inp__1.clone();  l_inp__1 = None
3861    a_view = a.view(-1)
3862    tmp = l_tmp__1.clone();  l_tmp__1 = None
3863    _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None
3864    set_ = a.set_(tmp);  set_ = None
3865    mul_ = a_view.mul_(2);  a_view = mul_ = None
3866    _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
3867    add = a + tmp;  a = tmp = None
3868    return (add,)
3869    """,
3870        )
3871        self.assertEqual(out, f(inp, tmp))
3872
3873    def test_two_hops_not_sharing_code_obj(self):
3874        pred, args = torch.tensor(True), (torch.ones(3, 3),)
3875
3876        def fn1(x):
3877            return x + 1
3878
3879        def fn2(x):
3880            return x - 1
3881
3882        from torch._dynamo.testing import CompileCounter
3883
3884        # Tests rely on automatic_dynamic = True
3885        with torch._dynamo.config.patch(automatic_dynamic_shapes=True):
3886            cnt = CompileCounter()
3887            torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, args)
3888            self.assertEqual(cnt.frame_count, 1)
3889
3890            args = (torch.randn(3, 3),)
3891            # No recompilation
3892            torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, args)
3893            self.assertEqual(cnt.frame_count, 1)
3894
3895            def cond_fn(x):
3896                return x.sum() > 0
3897
3898            args = (torch.randn(4, 4),)
3899            torch.compile(torch.while_loop, backend=cnt)(cond_fn, fn2, args)
3900            # recompilation
3901            self.assertEqual(cnt.frame_count, 2)
3902
3903            args = (torch.randn(4, 4),)
3904            torch.compile(torch.while_loop, backend=cnt)(cond_fn, fn2, args)
3905            self.assertEqual(cnt.frame_count, 2)
3906
3907            # With recompilation due to automatic dynamic
3908            # This also proves that while_loop doesn't share code obj with cond
3909            torch.compile(torch.cond, backend=cnt)(pred, fn1, fn2, (torch.randn(4, 4),))
3910            self.assertEqual(cnt.frame_count, 3)
3911
3912    def test_hop_raises_if_not_overriding_call(self):
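        # HigherOrderOperator subclasses are expected to override __call__;
        # instantiating one that does not should raise a TypeError naming the
        # offending class.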
3913        class WrongHop(torch._ops.HigherOrderOperator):
3914            pass
3915
3916        with self.assertRaisesRegex(TypeError, "WrongHop"):
3917            wrong_hop = WrongHop("wrong_hop")
3918
3919
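# Schema type strings exercised by the schema-generation tests in TestHopSchema.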
3920_hop_schema_test_schema_types = [
3921    "bool",
3922    "int",
3923    "float",
3924    "str",
3925    "Tensor",
3926    "SymInt",
3927    "SymBool",
3928    "GraphModule",
3929    "ScriptObj",
3930]
3931
3932
3933@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
3934class TestHopSchema(TestCase):
3935    def _get_example_val(self, ty: str):
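        # Build a representative example value for the given schema type
        # string, creating a fresh ShapeEnv-backed SymInt/SymBool where needed.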
3936        from torch.fx.experimental.sym_node import SymNode
3937        from torch.fx.experimental.symbolic_shapes import ShapeEnv
3938
3939        def create_symtype(cls, pytype, shape_env, val):
3940            from torch._dynamo.source import ConstantSource
3941
3942            symbol = shape_env.create_symbol(
3943                val,
3944                source=ConstantSource(
3945                    f"__testing_hop_schema{len(shape_env.var_to_val)}"
3946                ),
3947            )
3948            return cls(SymNode(symbol, shape_env, pytype, hint=val))
3949
3950        if ty == "bool":
3951            return True
3952        elif ty == "int":
3953            return 1
3954        elif ty == "float":
3955            return 1.0
3956        elif ty == "str":
3957            return "foo"
3958        elif ty == "Tensor":
3959            return torch.tensor(1)
3960        elif ty == "SymInt":
3961            shape_env = ShapeEnv()
3962            return create_symtype(torch.SymInt, int, shape_env, 1)
3963        elif ty == "SymBool":
3964            shape_env = ShapeEnv()
3965            return create_symtype(torch.SymBool, bool, shape_env, True)
3966        elif ty == "GraphModule":
3967
3968            def f(x):
3969                return x.sin()
3970
3971            return make_fx(f)(torch.ones(1))
3972        elif ty == "ScriptObj":
3973            from torch.testing._internal.torchbind_impls import (
3974                init_torchbind_implementations,
3975            )
3976
3977            init_torchbind_implementations()
3978            foo = torch.classes._TorchScriptTesting._Foo(3, 4)
3979            return foo
3980        else:
3981            raise NotImplementedError(ty)
3982
3983    @parametrize("schema_type", _hop_schema_test_schema_types)
3984    def test_type_gen(self, schema_type):
3985        from torchgen.gen_schema_utils import TypeGen
3986
3987        example_val = self._get_example_val(schema_type)
3988        ty = TypeGen.from_example(example_val)
3989        # Test the generated type can be parsed
3990        self.assertEqual(ty.parse(str(ty)), ty)
3991
3992    @parametrize("schema_type", _hop_schema_test_schema_types)
3993    def test_list_gen(self, schema_type):
3994        from torchgen.gen_schema_utils import TypeGen
3995
3996        example_val = self._get_example_val(schema_type)
3997        li1 = [example_val]
3998        li2 = [example_val, example_val]
3999        ty1 = TypeGen.from_example(li1)
4000        ty2 = TypeGen.from_example(li2)
4001        self.assertEqual(ty1.parse(str(ty1)), ty1)
4002        self.assertEqual(ty2.parse(str(ty2)), ty2)
4003
4004    def test_function_schema_gen(self):
4005        from torchgen.gen_schema_utils import FunctionSchemaGen
4006
4007        inps = [
4008            (schema_type + "_v", self._get_example_val(schema_type))
4009            for schema_type in _hop_schema_test_schema_types
4010        ]
4011        op_name = "test_op"
4012        schema1 = FunctionSchemaGen.from_example("test_op1", inps, torch.ones(1))
4013        schema2 = FunctionSchemaGen.from_example(
4014            "test_op2",
4015            inps,
4016            [
4017                torch.ones(1),
4018            ],
4019        )
4020        schema3 = FunctionSchemaGen.from_example(
4021            "test_op3", inps, [torch.ones(1), torch.ones(1)]
4022        )
4023        self.assertExpectedInline(
4024            str(schema1),
4025            """test_op1(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""",  # noqa: B950
4026        )
4027        self.assertExpectedInline(
4028            str(schema2),
4029            """test_op2(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""",  # noqa: B950
4030        )
4031        self.assertExpectedInline(
4032            str(schema3),
4033            """test_op3(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> (Tensor, Tensor)""",  # noqa: B950
4034        )
4035        self.assertEqual(schema1.parse(str(schema1)), schema1)
4036        self.assertEqual(schema2.parse(str(schema2)), schema2)
4037        self.assertEqual(schema3.parse(str(schema3)), schema3)
4038
4039    def test_while_loop_schema_gen(self):
4040        fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
4041        graph = make_fx(fn)(*inp).graph
4042        while_loop_node = next(
4043            node
4044            for node in graph.nodes
4045            if node.op == "call_function"
4046            and node.target is torch.ops.higher_order.while_loop
4047        )
4048        schema = torch._library.utils.hop_schema_from_fx_node(while_loop_node)
4049        self.assertExpectedInline(
4050            str(schema),
4051            """while_loop(GraphModule cond_fn, GraphModule body_fn, Tensor[2] carried_inputs, Tensor[3] additional_inputs) -> Tensor[2]""",  # noqa: B950
4052        )
4053        self.assertEqual(schema.parse(str(schema)), schema)
4054
4055
4056instantiate_parametrized_tests(TestHopSchema)
4057instantiate_parametrized_tests(TestControlFlowTraced)
4058
4059instantiate_parametrized_tests(TestControlFlow)
4060
4061if __name__ == "__main__":
4062    run_tests()
4063