# mypy: allow-untyped-defs
import contextlib
import logging

import torch
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._C._functorch import (
    _add_batch_dim,
    get_unwrapped,
    is_batchedtensor,
    maybe_get_bdim,
)
from torch._dispatch.python import suspend_functionalization
from torch._functorch.utils import exposed_in
from torch._guards import detect_fake_mode
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    _maybe_run_with_interpreter,
    _set_compilation_env,
    reenter_make_fx,
    unique_graph_id,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_pre_dispatch_torch_function_mode,
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import _get_current_dispatch_mode

from .utils import _from_fun, create_fw_bw_graph


log = logging.getLogger(__name__)

"""
We're going to define a `cond_op` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""


class CondOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("cond")

    def __call__(self, pred, true_fn, false_fn, operands):
        return super().__call__(pred, true_fn, false_fn, operands)


cond_op = CondOp()


@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
    r"""
    Conditionally applies `true_fn` or `false_fn`.

    .. warning::
        `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
        doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

    `cond` is a structured control flow operator. That is, it is like a Python if-statement,
    but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be
    capturable using torch.compile and torch.export.

    Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following::

        def cond(pred, true_branch, false_branch, operands):
            if pred:
                return true_branch(*operands)
            else:
                return false_branch(*operands)

    Args:
        pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element,
          indicating which branch function to apply.

        true_fn (Callable): A callable function (a -> b) that is within the
          scope that is being traced.

        false_fn (Callable): A callable function (a -> b) that is within the
          scope that is being traced. The true branch and false branch must
          have consistent inputs and outputs, meaning the inputs have to be
          the same, and the outputs have to be of the same type and shape.

        operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions.

    Example::

        def true_fn(x: torch.Tensor):
            return x.cos()
        def false_fn(x: torch.Tensor):
            return x.sin()
        return cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    Restrictions:
        - The conditional statement (aka `pred`) must meet one of the following constraints:

          - It's a `torch.Tensor` with only one element and a torch.bool dtype

          - It's a boolean expression, e.g. `x.shape[0] > 10` or `x.dim() > 1 and x.shape[1] > 10`

        - The branch functions (aka `true_fn`/`false_fn`) must meet all of the following constraints:

          - The function signature must match the operands.

          - The functions must return tensors with the same metadata, e.g. shape,
            dtype, etc.

          - The functions cannot have in-place mutations on inputs or global variables.
            (Note: in-place tensor operations such as `add_` for intermediate results
            are allowed in a branch; see the example below)
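
            For example, mutating a cloned intermediate is fine while mutating
            an input is not (illustrative sketch)::

                def ok_branch(x):
                    y = x.clone()
                    y.add_(1)  # mutating an intermediate tensor is allowed
                    return y.sum()

                def bad_branch(x):
                    x.add_(1)  # mutating an input is not allowed
                    return x.sum()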

    .. warning::
        Temporary limitations:

        - The **output** of branches must be a **single Tensor**. Pytrees of tensors will be supported in the future.

    """
    if torch.compiler.is_dynamo_compiling():
        return cond_op(pred, true_fn, false_fn, operands)

    if isinstance(pred, (bool, int, float)):
        log.warning(
            "Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
            " If you want torch.cond to preserve two branches, please make the predicate a boolean tensor or a SymBool."
        )
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)

    def _validate_input(pred, true_fn, false_fn, operands):
        if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
            raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")

        if isinstance(pred, torch.Tensor) and pred.numel() != 1:
            raise RuntimeError(
                f"Expected pred to be bool or single-element tensor, but got {pred}."
            )

        if not callable(true_fn) or not callable(false_fn):
            raise RuntimeError("Expect both branches to be callable.")

        if not isinstance(operands, (tuple, list)) or pytree.tree_any(
            lambda t: not isinstance(t, torch.Tensor), operands
        ):
            raise RuntimeError(
                "Expect operands to be a tuple of possibly nested dict/list/tuple that only "
                f"consists of tensor leaves, but got {operands}."
            )

    _validate_input(pred, true_fn, false_fn, operands)

    if not torch._dynamo.is_dynamo_supported():
        raise RuntimeError("torch.cond requires dynamo support.")

    # Dynamo expects a callable with a "__code__" attribute.
    # We cannot directly pass cond_op to it, so we wrap it in a dummy function.
    def _cond_op_wrapper(*args, **kwargs):
        return cond_op(*args, **kwargs)

    with _set_compilation_env():
        with torch._dynamo.utils.disable_cache_limit():
            with _temp_remove_pre_dispatch_torch_function_mode():
                return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)(
                    pred, true_fn, false_fn, operands
                )


def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
    # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py
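    # In short: with functionalization and proxy tracing suspended, each branch
    # is run on unwrapped versions of the operands to collect example outputs,
    # and create_fw_bw_graph then builds a forward graph plus a joint
    # (forward + backward) graph per branch. The four graphs returned here are
    # consumed by the Autograd implementation below (cond_autograd /
    # CondAutogradOp).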

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            fw_inputs = pytree.tree_map(_from_fun, operands)

            fw_outputs_true = pytree.tree_map(_from_fun, true_fn(*fw_inputs))
            if any(
                not isinstance(out, torch.Tensor)
                for out in fw_outputs_true
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of true_fn to only contain tensors or None. "
                    f"Got types {[type(out) for out in fw_outputs_true]}."
                )
            fw_outputs_false = pytree.tree_map(_from_fun, false_fn(*fw_inputs))
            if any(
                not isinstance(out, torch.Tensor)
                for out in fw_outputs_false
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of false_fn to only contain tensors or None. "
                    f"Got types {[type(out) for out in fw_outputs_false]}."
                )

            # TODO: There is a major issue in that create_fw_bw in the higher_order_op is invoked twice:
            # once in the forward path (as it should be) and once in the backward path, where it shouldn't be called.
            # If we can get rid of the second invocation, it would simplify this function.
            fw_true_graph, joint_true_graph = create_fw_bw_graph(
                true_fn, False, fw_inputs, fw_outputs_true
            )
            fw_false_graph, joint_false_graph = create_fw_bw_graph(
                false_fn, False, fw_inputs, fw_outputs_false
            )

        return fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph


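# trace_cond builds the proxy-traced form of cond: both branches are traced
# into GraphModules, checked for matching output metadata, registered as
# submodules on the tracer root, and a single call_function node for cond is
# emitted. Roughly (illustrative sketch of the resulting graph; the names
# true_graph_0/false_graph_0 come from unique_graph_id):
#
#   true_graph_0 = <GraphModule traced from true_fn>
#   false_graph_0 = <GraphModule traced from false_fn>
#   out = cond(pred, true_graph_0, false_graph_0, operands)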
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    assert isinstance(
        operands, (list, tuple)
    ), "Cond operands must be a list or tuple of tensors"
    assert all(
        isinstance(o, torch.Tensor) for o in operands
    ), "Cond operands must be a list of tensors"

    true_graph = reenter_make_fx(true_fn)(*operands)
    false_graph = reenter_make_fx(false_fn)(*operands)

    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == "output":
            true_outs.extend(node.args)

    for node in false_graph.graph.nodes:
        if node.op == "output":
            false_outs.extend(node.args)

    flat_true_outs = pytree.arg_tree_leaves(*true_outs)
    flat_false_outs = pytree.arg_tree_leaves(*false_outs)
    if len(flat_true_outs) != len(flat_false_outs):
        raise torch._dynamo.exc.CondOpArgsMismatchError(
            f"Expected to return same number of outputs but got:"
            f"\n  true branch returns {len(flat_true_outs)} item(s)"
            f"\n  false branch returns {len(flat_false_outs)} item(s)"
        )

    for i in range(0, len(flat_true_outs)):
        true_out = flat_true_outs[i]
        false_out = flat_false_outs[i]

        # Note that we need to skip the check for requires_grad because we're past
        # the autograd key during tracing, so the requires_grad attribute of the
        # tensors is no longer meaningful. See Note [invariants for node meta 'val']
        def _same_meta_except_requires_grad(true_out, false_out):
            if true_out is None and false_out is None:
                return True
            elif true_out is None or false_out is None:
                # Consider the following case:
                # def true_fn(x, y):
                #   return x * y
                #
                # def false_fn(x, y):
                #   return x.sin()
                #
                # We'll get the following graphs for backward:
                # def backward_true_fn(x, y, grad_out):
                #  return grad_out * y, grad_out * x
                #
                # def backward_false_fn(x, y, grad_out):
                #  return grad_out, None
                #
                # This means that when we make_fx into the backward graph, the
                # two graphs produce outputs with mismatched metadata (a tensor
                # in one branch, None in the other), which is undesirable.
                #
                # Ideally, we should provide an optional type to indicate that one of the branches might
                # return None. But we'll just let it pass for now and let downstream/runtime handle it.
                #
                # Note that this corner case should **only** happen when the user wants to trace the
                # backward graph, because if it's forward, dynamo will error.
                return True
            true_meta = true_out.meta.get("tensor_meta", None)
            false_meta = false_out.meta.get("tensor_meta", None)
            return (
                true_meta.shape == false_meta.shape
                and true_meta.dtype == false_meta.dtype
                and true_meta.stride == false_meta.stride
            )

        if not _same_meta_except_requires_grad(true_out, false_out):
            raise torch._dynamo.exc.CondOpArgsMismatchError(
                f"Expected each tensor to have same metadata but got:"
                f"\n  {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
                f"\n  {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
            )

    i, true_name = unique_graph_id(proxy_mode, prefix="true_graph")

    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, operands)

    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}
    )

    # At this point, we're *guaranteed* that the outputs of the true and false
    # branches are indistinguishable. So, as this is just for tracing purposes,
    # run one of them (the false branch here; see the TODO below).

    # TODO: the unbacked symbol allocations MUST NOT leak out, if you want to
    # support this we need to arrange for the reenter_make_fx unbacked SymInts
    # to be used, AND we need to arrange for some sort of unification between
    # the two branches (but not really unification; e.g., if one branch
    # returns [u0] and the other returns [5] this is OK but you MUST NOT
    # conclude the result is 5.  Also if one branch returns [3] and another
    # branch returns [5] you can make it work by immediately allocating a new
    # unbacked SymInt here).
    ignore_fresh_unbacked = contextlib.nullcontext()
    if (fake_mode := detect_fake_mode()) and fake_mode.shape_env:
        ignore_fresh_unbacked = fake_mode.shape_env.ignore_fresh_unbacked_symbols()

    # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in
    # a FakeTensorMode error:
    # `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
    # TODO: Sometimes the operands are not completely FakeTensors; something seems to have gone
    # wrong in dynamo? Because of that, it sometimes runs real computation and re-triggers
    # downstream dispatch keys.
    with ignore_fresh_unbacked:
        out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_op_dense(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)


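# Autograd support: cond's backward is itself a cond. The forward dispatches
# below the Autograd key to the forward branch graphs; the backward re-invokes
# cond_op with the same predicate over the joint (backward) graphs produced by
# create_fw_bw_graph_branches.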
class CondAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        pred,
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
        *operands,
    ):
        ctx._pred = pred
        ctx._joint_true_graph = joint_true_graph
        ctx._joint_false_graph = joint_false_graph
        ctx.save_for_backward(*operands)

        with torch._C._AutoDispatchBelowAutograd():
            return cond_op(pred, fw_true_graph, fw_false_graph, operands)

    @staticmethod
    def backward(ctx, *flat_grads):
        operands = ctx.saved_tensors

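        # The joint graphs take the upstream gradients followed by the saved
        # operands and return gradients for the operands; the five leading
        # Nones below correspond to forward's non-differentiable arguments
        # (pred and the four graphs).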
        grads = cond_op(
            ctx._pred,
            ctx._joint_true_graph,
            ctx._joint_false_graph,
            flat_grads + operands,
        )
        return None, None, None, None, None, *grads


@cond_op.py_impl(DispatchKey.Autograd)
def cond_autograd(pred, true_fn, false_fn, operands):
    # A shortcut for the case where none of the inputs require gradient:
    # we skip tracing the forward and backward graphs.
    if pytree.tree_all_only(
        torch.Tensor,
        lambda t: not t.requires_grad,  # type: ignore[union-attr]
        (pred, operands),
    ):
        with torch._C._AutoDispatchBelowAutograd():
            return cond_op(pred, true_fn, false_fn, operands)

    (
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
    ) = create_fw_bw_graph_branches(true_fn, false_fn, *operands)
    flat_out = CondAutogradOp.apply(
        pred,
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
        *operands,
    )
    return flat_out


@cond_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, pred, true_fn, false_fn, operands):
    return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)


@cond_op.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
    # Ignore here, because if you've gotten here but you're not manually
    # tracing the inner graphs, that means that you intend to reuse the graph
    # directly.  Which means the old unbacked symbol bindings are appropriate.
    # This strategy will not work if unbacked symbols can escape.
    ignore_fresh_unbacked = contextlib.nullcontext()
    if mode.shape_env:
        ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()

    with mode, ignore_fresh_unbacked:
        true_outs = true_fn(*operands)
        flat_true_outs = pytree.tree_leaves(true_outs)
        flat_false_outs = pytree.tree_leaves(false_fn(*operands))
    if len(flat_true_outs) != len(flat_false_outs):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")

    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        true_meta = _extract_tensor_metadata(true_out)
        false_meta = _extract_tensor_metadata(false_out)
        if true_meta != false_meta:
            raise torch._dynamo.exc.CondOpArgsMismatchError(
                f"Expected each tensor to have same metadata but got:"
                f"\n  {true_fn.__name__} returns {true_meta}"
                f"\n  {false_fn.__name__} returns {false_meta}"
            )
    return true_outs


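# Functionalization: unwrap the functional tensors, functionalize both
# branches, and reject branches that might mutate or alias their inputs
# (cond requires pure branches), then redispatch to the next key.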
@cond_op.py_functionalize_impl
def cond_func(ctx, pred, true_fn, false_fn, inputs):
    unwrapped_inputs = ctx.unwrap_tensors(inputs)
    unwrapped_pred = ctx.unwrap_tensors(pred)
    with ctx.redispatch_to_next() as m:
        functional_true = ctx.functionalize(_maybe_run_with_interpreter(true_fn))
        functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn))
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        for branch in [functional_true, functional_false]:
            if _has_potential_branch_input_mutation(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be modifying the input!"
                )
        for branch in [true_fn, false_fn]:
            if _has_potential_branch_input_alias(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be aliasing the input!"
                )

        cond_return = cond_op(
            unwrapped_pred, functional_true, functional_false, unwrapped_inputs
        )
        return ctx.wrap_tensors(cond_return)


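# vmap batch rule: when the predicate itself is batched there is no single
# branch to take, so both branches are evaluated and combined element-wise
# with torch.where under a nested vmap; when the predicate is unbatched, each
# branch is simply vmapped and cond is dispatched on the lowered interpreter.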
@cond_op.py_impl(torch._C._functorch.TransformType.Vmap)
def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs):
    assert isinstance(
        inputs, (list, tuple)
    ), "Cond inputs must be a list or tuple of tensors"
    assert all(
        isinstance(i, torch.Tensor) for i in inputs
    ), "Cond inputs must be a list of tensors"

    pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred

    # unbatched tensors are not vmapped
    tensors, in_dims = zip(
        *[
            (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None)
            for t in inputs
        ]
    )

    if is_batchedtensor(pred):
        # prepend "pred" and vmap everything
        tensors = (pred_,) + tensors
        in_dims = (0,) + in_dims

        def fn(p, *args):
            t = true_fn(*args)
            f = false_fn(*args)
            return torch.where(p, t[0], f[0])

        with interpreter.lower():
            result = torch.vmap(fn, in_dims=in_dims)(*tensors)

    else:
        # predicate is known at this stage and it is a boolean expression or a
        # tensor with one element.
        true_fn = torch.vmap(true_fn, in_dims=in_dims)
        false_fn = torch.vmap(false_fn, in_dims=in_dims)

        with interpreter.lower():
            result = cond_op(pred, true_fn, false_fn, tensors)

    if not isinstance(result, tuple):
        result = (result,)
    lvl = interpreter.level()
    return tuple([_add_batch_dim(r, 0, lvl) for r in result])