# mypy: allow-untyped-defs
import contextlib
import logging

import torch
import torch._subclasses.functional_tensor
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._C._functorch import (
    _add_batch_dim,
    get_unwrapped,
    is_batchedtensor,
    maybe_get_bdim,
)
from torch._dispatch.python import suspend_functionalization
from torch._functorch.utils import exposed_in
from torch._guards import detect_fake_mode
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    _maybe_run_with_interpreter,
    _set_compilation_env,
    reenter_make_fx,
    unique_graph_id,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_pre_dispatch_torch_function_mode,
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import _get_current_dispatch_mode

from .utils import _from_fun, create_fw_bw_graph


log = logging.getLogger(__name__)

"""
We're going to define a `cond_op` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""


class CondOp(HigherOrderOperator):
    def __init__(self):
        super().__init__("cond")

    def __call__(self, pred, true_fn, false_fn, operands):
        return super().__call__(pred, true_fn, false_fn, operands)


cond_op = CondOp()


@exposed_in("torch")
def cond(pred, true_fn, false_fn, operands):
    r"""
    Conditionally applies `true_fn` or `false_fn`.

    .. warning::
        `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and
        doesn't currently support training. Please look forward to a more stable implementation in a future version of PyTorch.
        Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

    `cond` is a structured control flow operator. That is, it is like a Python if-statement,
    but has restrictions on `true_fn`, `false_fn`, and `operands` that enable it to be
    capturable using torch.compile and torch.export.

    Assuming the constraints on `cond`'s arguments are met, `cond` is equivalent to the following::

        def cond(pred, true_branch, false_branch, operands):
            if pred:
                return true_branch(*operands)
            else:
                return false_branch(*operands)

    Args:
        pred (Union[bool, torch.Tensor]): A boolean expression or a tensor with one element,
          indicating which branch function to apply.

        true_fn (Callable): A callable function (a -> b) that is within the
          scope that is being traced.

        false_fn (Callable): A callable function (a -> b) that is within the
          scope that is being traced. The true branch and false branch must
          have consistent inputs and outputs, meaning the inputs have to be
          the same, and the outputs have to be of the same type and shape.

        operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions.

    Example::

        def true_fn(x: torch.Tensor):
            return x.cos()
        def false_fn(x: torch.Tensor):
            return x.sin()
        return cond(x.shape[0] > 4, true_fn, false_fn, (x,))

    Restrictions:
        - The conditional statement (aka `pred`) must meet one of the following constraints:

          - It's a `torch.Tensor` with only one element and torch.bool dtype

          - It's a boolean expression, e.g. `x.shape[0] > 10` or `x.dim() > 1 and x.shape[1] > 10`

        - The branch functions (aka `true_fn`/`false_fn`) must meet all of the following constraints:

          - The function signature must match with operands.

          - The function must return a tensor with the same metadata, e.g. shape,
            dtype, etc.

          - The function cannot have in-place mutations on inputs or global variables.
            (Note: in-place tensor operations such as `add_` for intermediate results
            are allowed in a branch)

    .. warning::
        Temporary limitations:

        - The **output** of branches must be a **single Tensor**. Pytree of tensors will be supported in the future.

    """
    if torch.compiler.is_dynamo_compiling():
        return cond_op(pred, true_fn, false_fn, operands)

    if isinstance(pred, (bool, int, float)):
        log.warning(
            "Pred is a Python constant. When used with torch.cond, it executes only one of the branches."
            " If you want torch.cond to preserve two branches, please make the predicate a boolean tensor or a SymBool."
        )
        if pred:
            return true_fn(*operands)
        else:
            return false_fn(*operands)

    def _validate_input(pred, true_fn, false_fn, operands):
        if not isinstance(pred, (bool, torch.Tensor, torch.SymBool)):
            raise RuntimeError(f"Expected pred to be bool or tensor, but got {pred}.")

        if isinstance(pred, torch.Tensor) and pred.numel() != 1:
            raise RuntimeError(
                f"Expected pred to be bool or single-element tensor, but got {pred}."
            )

        if not callable(true_fn) or not callable(false_fn):
            raise RuntimeError("Expect both branches to be callable.")

        if not isinstance(operands, (tuple, list)) or pytree.tree_any(
            lambda t: not isinstance(t, torch.Tensor), operands
        ):
            raise RuntimeError(
                "Expect operands to be a tuple of possibly nested dict/list/tuple that only "
                f"consists of tensor leaves, but got {operands}."
            )

    _validate_input(pred, true_fn, false_fn, operands)

    if not torch._dynamo.is_dynamo_supported():
        raise RuntimeError("torch.cond requires dynamo support.")
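
    # Illustrative sketch (hypothetical shapes/values) of what the torch.compile
    # call below captures: with a tensor or SymBool predicate, e.g.
    #
    #   x = torch.randn(8)
    #   torch.cond(x.sum() > 0, lambda t: t.cos(), lambda t: t.sin(), (x,))
    #
    # both branches are traced into subgraphs and the result is a single
    # cond_op(pred, true_graph, false_graph, (x,)) call, in contrast to the
    # Python-constant fast path above, which eagerly runs only one branch.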

    # Dynamo is expecting a callable with a "__code__" attribute.
    # We cannot directly pass cond_op to it, so we wrap it in a dummy function.
    def _cond_op_wrapper(*args, **kwargs):
        return cond_op(*args, **kwargs)

    with _set_compilation_env():
        with torch._dynamo.utils.disable_cache_limit():
            with _temp_remove_pre_dispatch_torch_function_mode():
                return torch.compile(_cond_op_wrapper, backend="eager", fullgraph=True)(
                    pred, true_fn, false_fn, operands
                )


def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
    # See Note [HOP create fw_bw graph] in create_fw_bw_graph in utils.py

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():
            fw_inputs = pytree.tree_map(_from_fun, operands)

            fw_outputs_true = pytree.tree_map(_from_fun, true_fn(*fw_inputs))
            if any(
                not isinstance(out, torch.Tensor)
                for out in fw_outputs_true
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of true_fn to only contain tensors or None. "
                    f"Got types {[type(out) for out in fw_outputs_true]}."
                )
            fw_outputs_false = pytree.tree_map(_from_fun, false_fn(*fw_inputs))
            if any(
                not isinstance(out, torch.Tensor)
                for out in fw_outputs_false
                if out is not None
            ):
                raise RuntimeError(
                    "Expect outputs of false_fn to only contain tensors or None. "
                    f"Got types {[type(out) for out in fw_outputs_false]}."
                )

            # TODO: There is a major issue in that create_fw_bw_graph in this higher-order op
            # is invoked twice: once in the forward path (as it should be) and once in the
            # backward path, where it shouldn't be called. If we can get rid of the second
            # invocation, it would simplify this function.
            fw_true_graph, joint_true_graph = create_fw_bw_graph(
                true_fn, False, fw_inputs, fw_outputs_true
            )
            fw_false_graph, joint_false_graph = create_fw_bw_graph(
                false_fn, False, fw_inputs, fw_outputs_false
            )

        return fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    assert isinstance(
        operands, (list, tuple)
    ), "Cond operands must be a list or tuple of tensors"
    assert all(
        isinstance(o, torch.Tensor) for o in operands
    ), "Cond operands must be a list of tensors"

    true_graph = reenter_make_fx(true_fn)(*operands)
    false_graph = reenter_make_fx(false_fn)(*operands)

    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == "output":
            true_outs.extend(node.args)

    for node in false_graph.graph.nodes:
        if node.op == "output":
            false_outs.extend(node.args)

    flat_true_outs = pytree.arg_tree_leaves(*true_outs)
    flat_false_outs = pytree.arg_tree_leaves(*false_outs)
    if len(flat_true_outs) != len(flat_false_outs):
        raise torch._dynamo.exc.CondOpArgsMismatchError(
            f"Expected both branches to return the same number of outputs, but got:"
            f"\n true branch returns {len(flat_true_outs)} item(s)"
            f"\n false branch returns {len(flat_false_outs)} item(s)"
        )

    for i in range(0, len(flat_true_outs)):
        true_out = flat_true_outs[i]
        false_out = flat_false_outs[i]

        # Note that we need to skip the check for requires_grad because we're past
        # the autograd key during tracing, so the requires_grad attribute of the
        # tensors is no longer accurate. See Note [invariants for node meta 'val']
        def _same_meta_except_requires_grad(true_out, false_out):
            if true_out is None and false_out is None:
                return True
            elif true_out is None or false_out is None:
                # Consider the following case:
                # def true_fn(x, y):
                #     return x * y
                #
                # def false_fn(x, y):
                #     return x.sin()
                #
                # We'll get the following graphs for backward:
                # def backward_true_fn(x, y, grad_out):
                #     return grad_out * y, grad_out * x
                #
                # def backward_false_fn(x, y, grad_out):
                #     return grad_out, None
                #
                # This means that when we make_fx into the backward graph, one branch
                # can produce a None output where the other produces a tensor with
                # metadata, which is undesirable.
                #
                # Ideally, we should provide an optional type to indicate that one of the branches might
                # return None. But we'll just let it pass for now and let downstream/runtime handle it.
                #
                # Note that this corner case should **only** happen when the user wants to trace the
                # backward graph, because if it's the forward, dynamo will error.
                return True
            true_meta = true_out.meta.get("tensor_meta", None)
            false_meta = false_out.meta.get("tensor_meta", None)
            return (
                true_meta.shape == false_meta.shape
                and true_meta.dtype == false_meta.dtype
                and true_meta.stride == false_meta.stride
            )

        if not _same_meta_except_requires_grad(true_out, false_out):
            raise torch._dynamo.exc.CondOpArgsMismatchError(
                f"Expected each tensor to have the same metadata but got:"
                f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
                f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
            )

    i, true_name = unique_graph_id(proxy_mode, prefix="true_graph")

    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, operands)

    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}
    )

    # At this point, we're *guaranteed* that whether an output came from the
    # true or false branch is indistinguishable. So, as this is just for tracing
    # purposes, choose the true branch.

    # TODO: the unbacked symbol allocations MUST NOT leak out, if you want to
    # support this we need to arrange for the reenter_make_fx unbacked SymInts
    # to be used, AND we need to arrange for some sort of unification between
    # the two branches (but not really unification; e.g., if one branch
    # returns [u0] and the other returns [5] this is OK but you MUST NOT
    # conclude the result is 5. Also if one branch returns [3] and another
    # branch returns [5] you can make it work by immediately allocating a new
    # unbacked SymInt here).
    ignore_fresh_unbacked = contextlib.nullcontext()
    if (fake_mode := detect_fake_mode()) and fake_mode.shape_env:
        ignore_fresh_unbacked = fake_mode.shape_env.ignore_fresh_unbacked_symbols()

    # TODO: Uhh.... it shouldn't matter, but changing this to true_fn results in
    # a FakeTensorMode error:
    # `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
    # TODO: Sometimes the operands are not completely FakeTensor; something seems to have
    # gone wrong in dynamo? Because of that, it sometimes runs real computation and
    # re-triggers downstream dispatch keys.
    with ignore_fresh_unbacked:
        out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
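
# Conceptual sketch of what trace_cond emits into the FX graph being built
# (hypothetical module/node names for the first cond in a graph): the branch
# graphs are registered as submodules and the op becomes a single node, roughly
#
#   true_graph_0 = <GraphModule traced from true_fn>
#   false_graph_0 = <GraphModule traced from false_fn>
#   out = torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, (x,))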


@cond_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def cond_op_dense(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)


class CondAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        pred,
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
        *operands,
    ):
        ctx._pred = pred
        ctx._joint_true_graph = joint_true_graph
        ctx._joint_false_graph = joint_false_graph
        ctx.save_for_backward(*operands)

        with torch._C._AutoDispatchBelowAutograd():
            return cond_op(pred, fw_true_graph, fw_false_graph, operands)

    @staticmethod
    def backward(ctx, *flat_grads):
        operands = ctx.saved_tensors

        grads = cond_op(
            ctx._pred,
            ctx._joint_true_graph,
            ctx._joint_false_graph,
            flat_grads + operands,
        )
        return None, None, None, None, None, *grads


@cond_op.py_impl(DispatchKey.Autograd)
def cond_autograd(pred, true_fn, false_fn, operands):
    # As a shortcut, when none of the inputs requires gradient we skip
    # tracing the forward and backward graphs.
    if pytree.tree_all_only(
        torch.Tensor,
        lambda t: not t.requires_grad,  # type: ignore[union-attr]
        (pred, operands),
    ):
        with torch._C._AutoDispatchBelowAutograd():
            return cond_op(pred, true_fn, false_fn, operands)

    (
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
    ) = create_fw_bw_graph_branches(true_fn, false_fn, *operands)
    flat_out = CondAutogradOp.apply(
        pred,
        fw_true_graph,
        fw_false_graph,
        joint_true_graph,
        joint_false_graph,
        *operands,
    )
    return flat_out
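
# Conceptual sketch of the Autograd dispatch above (hypothetical operands (x, y)
# and upstream gradients g): CondAutogradOp.forward re-dispatches to
#   cond_op(pred, fw_true_graph, fw_false_graph, (x, y))
# and CondAutogradOp.backward re-dispatches with the same saved predicate to
#   cond_op(pred, joint_true_graph, joint_false_graph, (g, x, y))
# so the branch that ran in the forward is also the one whose joint (backward)
# graph computes the gradients for the operands.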


@cond_op.py_impl(ProxyTorchDispatchMode)
def inner(mode, pred, true_fn, false_fn, operands):
    return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)


@cond_op.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
    # Ignore here, because if you've gotten here but you're not manually
    # tracing the inner graphs, that means that you intend to reuse the graph
    # directly, which means the old unbacked symbol bindings are appropriate.
    # This strategy will not work if unbacked symbols can escape.
    ignore_fresh_unbacked = contextlib.nullcontext()
    if mode.shape_env:
        ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()

    with mode, ignore_fresh_unbacked:
        true_outs = true_fn(*operands)
        flat_true_outs = pytree.tree_leaves(true_outs)
        flat_false_outs = pytree.tree_leaves(false_fn(*operands))

    if len(flat_true_outs) != len(flat_false_outs):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")

    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        true_meta = _extract_tensor_metadata(true_out)
        false_meta = _extract_tensor_metadata(false_out)
        if true_meta != false_meta:
            raise torch._dynamo.exc.CondOpArgsMismatchError(
                f"Expected each tensor to have the same metadata but got:"
                f"\n {true_fn.__name__} returns {true_meta}"
                f"\n {false_fn.__name__} returns {false_meta}"
            )
    return true_outs


@cond_op.py_functionalize_impl
def cond_func(ctx, pred, true_fn, false_fn, inputs):
    unwrapped_inputs = ctx.unwrap_tensors(inputs)
    unwrapped_pred = ctx.unwrap_tensors(pred)
    with ctx.redispatch_to_next() as m:
        functional_true = ctx.functionalize(_maybe_run_with_interpreter(true_fn))
        functional_false = ctx.functionalize(_maybe_run_with_interpreter(false_fn))
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        for branch in [functional_true, functional_false]:
            if _has_potential_branch_input_mutation(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be modifying the input!"
                )
        for branch in [true_fn, false_fn]:
            if _has_potential_branch_input_alias(
                branch, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    "One of torch.cond's branches might be aliasing the input!"
                )

        cond_return = cond_op(
            unwrapped_pred, functional_true, functional_false, unwrapped_inputs
        )
        return ctx.wrap_tensors(cond_return)


@cond_op.py_impl(torch._C._functorch.TransformType.Vmap)
def cond_batch_rule(interpreter, pred, true_fn, false_fn, inputs):
    assert isinstance(
        inputs, (list, tuple)
    ), "Cond inputs must be a list or tuple of tensors"
    assert all(
        isinstance(i, torch.Tensor) for i in inputs
    ), "Cond inputs must be a list of tensors"

    pred_ = get_unwrapped(pred) if is_batchedtensor(pred) else pred

    # unbatched tensors are not vmapped
    tensors, in_dims = zip(
        *[
            (get_unwrapped(t), maybe_get_bdim(t)) if is_batchedtensor(t) else (t, None)
            for t in inputs
        ]
    )

    if is_batchedtensor(pred):
        # prepend "pred" and vmap everything
        tensors = (pred_,) + tensors
        in_dims = (0,) + in_dims

        def fn(p, *args):
            t = true_fn(*args)
            f = false_fn(*args)
            return torch.where(p, t[0], f[0])

        with interpreter.lower():
            result = torch.vmap(fn, in_dims=in_dims)(*tensors)

    else:
        # predicate is known at this stage and it is a boolean expression or a
        # tensor with one element.
        true_fn = torch.vmap(true_fn, in_dims=in_dims)
        false_fn = torch.vmap(false_fn, in_dims=in_dims)

        with interpreter.lower():
            result = cond_op(pred, true_fn, false_fn, tensors)

    if not isinstance(result, tuple):
        result = (result,)
    lvl = interpreter.level()
    return tuple([_add_batch_dim(r, 0, lvl) for r in result])
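
# Conceptual sketch of the Vmap batching rule above (hypothetical shapes): with a
# batched predicate `pred` of shape (B,) and a batched input `x` of shape (B, N),
# both branches are evaluated under vmap and the predicate selects per batch
# element, roughly
#   result[i] = torch.where(pred[i], true_fn(x[i]), false_fn(x[i]))
# With an unbatched predicate, the branches themselves are vmapped and cond_op
# picks one of them as usual.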