# mypy: ignore-errors

import contextlib
import functools
import itertools
import logging
import types

from typing import Dict, List, Optional, TYPE_CHECKING

import torch._C
import torch.fx
import torch.nn
import torch.onnx.operators
from torch._dynamo.utils import get_fake_value
from torch._dynamo.variables import ConstantVariable
from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.builtin import BuiltinVariable
from torch._dynamo.variables.functions import UserFunctionVariable
from torch._dynamo.variables.tensor import SymNodeVariable
from torch._guards import Source
from torch._ops import HigherOrderOperator
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils import _pytree as pytree
from .. import variables

from ..exc import UncapturedHigherOrderOpError, unimplemented, Unsupported
from ..source import AttrSource
from ..utils import proxy_args_kwargs
from .dicts import ConstDictVariable
from .lazy import LazyVariableTracker
from .lists import ListVariable, TupleVariable

if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


log = logging.getLogger(__name__)


def raise_hard_error_if_graph_break(reason):
    def deco(fn):
        @functools.wraps(fn)
        def graph_break_as_hard_error(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except Unsupported as e:
                msg = " Scroll up to find out what causes the graph break."
                raise UncapturedHigherOrderOpError(reason + msg) from e

        return graph_break_as_hard_error

    return deco


@contextlib.contextmanager
def dynamo_enable_grad(tx, enable=True):
    from . import GradModeVariable

    org_value = torch.is_grad_enabled()
    try:
        GradModeVariable.create(tx, enable, initialized=True)
        yield
    finally:
        GradModeVariable.create(tx, org_value, initialized=True)


def only_consist_of(var, types, allow_none=False):
    if isinstance(var, types):
        return True
    if allow_none and var.is_python_constant() and var.as_python_constant() is None:
        return True
    if isinstance(var, (TupleVariable, ListVariable)):
        return all(only_consist_of(item, types, allow_none) for item in var.items)
    if isinstance(var, ConstDictVariable):
        return all(
            only_consist_of(item, types, allow_none) for item in var.items.values()
        )
    return False


# A more readable syntactic sugar for creating a UserFunctionVariable for f
# and running call_function on it. Make it return a function to preserve the
# calling convention of the original f.
def _make_inlined(tx, f):
    assert callable(f), "Expect f to be a python callable."

    def inline_call(*args, **kwargs):
        return UserFunctionVariable(f).call_function(tx, args, kwargs)

    return inline_call


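# A hedged usage sketch of _make_inlined (it mirrors call sites later in this
# file; the variable names here are illustrative only):
#
#     unflatten = _make_inlined(tx, pytree.tree_unflatten)
#     result_vt = unflatten(flat_list_variable, treespec)
#
# i.e. pytree.tree_unflatten is inlined under Dynamo, and the call returns a
# VariableTracker rather than a concrete Python value.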


def _call_function_and_unflatten_output(
    tx, fn, args, kwargs, flat_example_value, ret_treespec
):
    from .builder import wrap_fx_proxy

    # Store the invocation as a call
    flat_variable = wrap_fx_proxy(
        tx=tx,
        proxy=tx.output.create_proxy(
            "call_function",
            fn,
            args=args,
            kwargs=kwargs,
        ),
        example_value=flat_example_value,
    )

    # Transform variable back into a list (previously made into a tuple by
    # speculate_subgraph function) so as to respect the pytree API typing.
    flat_list_variable = BuiltinVariable(list).call_function(tx, [flat_variable], {})
    return (
        _make_inlined(tx, pytree.tree_unflatten)(flat_list_variable, ret_treespec)
        if ret_treespec
        else flat_variable
    )


def _assert_tensors_nonaliasing(inputs, outputs):
    input_tensor_ids = {
        id(t) for t in pytree.tree_leaves(inputs) if isinstance(t, torch.Tensor)
    }
    output_tensor_ids = {
        id(t) for t in pytree.tree_leaves(outputs) if isinstance(t, torch.Tensor)
    }
    assert input_tensor_ids.isdisjoint(
        output_tensor_ids
    ), "inputs to function body cannot alias outputs"


def _check_supported_callable_arg(tx, func_var: VariableTracker, arg_name):
    is_callable = (
        BuiltinVariable(callable).call_function(tx, [func_var], {}).as_python_constant()
    )
    if not is_callable:
        unimplemented(f"{arg_name} is of unsupported callable type {str(func_var)}.")


def validate_args_and_maybe_create_graph_inputs(
    sub_args,
    tracer,
    tx,
    set_subgraph_inputs,
    description,
):
    from . import AutogradFunctionContextVariable
    from .builder import wrap_fx_proxy_cls

    assert tracer.parent is not None

    if set_subgraph_inputs == "flatten_manual":
        flat_args, tree_spec = _make_inlined(tx, pytree.tree_flatten)(
            ListVariable(sub_args)
        ).unpack_var_sequence(tx)

        flat_inputs = validate_args_and_maybe_create_graph_inputs(
            flat_args.unpack_var_sequence(tx),
            tracer,
            tx,
            set_subgraph_inputs="manual",
            description=description,
        )

        return _make_inlined(tx, pytree.tree_unflatten)(
            ListVariable(flat_inputs), tree_spec
        ).unpack_var_sequence(tx)
    else:
        args = []
        for a in sub_args:
            assert isinstance(a, VariableTracker)
            if set_subgraph_inputs == "automatic":
                args.append(a)
                continue
            elif set_subgraph_inputs == "semi_automatic":
                if isinstance(a, AutogradFunctionContextVariable):
                    tracer.create_graph_input(a.as_proxy().node.name)
                elif a.maybe_fx_node() is not None:
                    node = a.maybe_fx_node()
                    new_proxy = tracer.create_graph_input(node.name)
                    example_value = (
                        node.meta["example_value"]
                        if "example_value" in node.meta
                        else None
                    )
                    a = wrap_fx_proxy_cls(
                        target_cls=type(a),
                        tx=tx,
                        proxy=new_proxy,
                        example_value=example_value,
                    )
                args.append(a)
                continue

            if a.is_python_constant():
                # This arg is not used in the body of the higher order op.
                # Currently, this new input is added to make the calls
                # happy, which expect a fixed number of arguments. In the
                # future, we can clean this up.
                tracer.create_graph_input("const")
                new_arg = a
            # Weird special case, we probably want to delete it or fold it
            # into the next case (of `a` being placeable into a graph)
            elif isinstance(a, AutogradFunctionContextVariable):
                tracer.create_graph_input(a.as_proxy().node.name)
                new_arg = a
            # If `a` can be put into a graph
            elif a.maybe_fx_node() is not None:
                node = a.maybe_fx_node()
                new_proxy = tracer.create_graph_input(node.name)
                example_value = (
                    node.meta["example_value"] if "example_value" in node.meta else None
                )
                new_arg = wrap_fx_proxy_cls(
                    target_cls=type(a),
                    tx=tx,
                    proxy=new_proxy,
                    example_value=example_value,
                )
            # If `a` cannot be put into a graph
            else:
                # HOPs work much better if they use speculate_subgraph(set_subgraph_inputs="automatic").
                unimplemented(
                    f"{description} with body that accepts non-Tensors as input. "
                    f"Got: {a.python_type()}"
                )
            args.append(new_arg)
        return args


# This helper function is used to make sure two graphs share the same input signature. For example,
# in torch.cond, two branches might lift different sets of tensors as inputs. This function helps to
# dedup the inputs and modify the graphs to take the same set of inputs.
def _merge_graph_inputs(
    l_graph, l_lifted_freevars, l_name, r_graph, r_lifted_freevars, r_name
):
    def dedup_and_sort_lifted_freevars(l_lifted_freevars, r_lifted_freevars):
        # The nn module attributes are guaranteed to be registered into the top-level graph module during
        # higher order op speculation. Therefore, get_attr nodes in two branches with the same
        # target refer to the same attribute and we can safely deduplicate them with their target.
        #
        # Note: ideally, dynamo should just create a single proxy for the same attribute of a nn module. But
        # true_branch and false_branch belong to two separate tracing contexts, they may register the same
        # attribute to top level separately. This creates two get_attr proxies for the same attribute
        # that have different metadata such as stack_trace (one stack trace for the true_branch,
        # and the other for false_branch). It seems better to discard the proxy explicitly in cond
        # than make dynamo create a single proxy for the same get_attr target.
        def shared_getattrs(l_lifted_proxies, r_lifted_proxies):
            true_targets = {
                proxy.node.target: proxy
                for proxy in l_lifted_proxies
                if proxy.node.op == "get_attr"
            }
            l_shared_getattrs = {}
            r_shared_getattrs = {}

            for false_proxy in r_lifted_proxies:
                if (
                    false_proxy.node.op == "get_attr"
                    and false_proxy.node.target in true_targets
                ):
                    true_proxy = true_targets[false_proxy.node.target]
                    l_shared_getattrs[true_proxy] = true_proxy
                    r_shared_getattrs[false_proxy] = true_proxy
            return l_shared_getattrs, r_shared_getattrs

        l_shared_getattrs, r_shared_getattrs = shared_getattrs(
            l_lifted_freevars.keys(), r_lifted_freevars.keys()
        )

        l_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union(
            l_shared_getattrs.keys()
        )
        r_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union(
            r_shared_getattrs.keys()
        )
        unique_l_freevars = l_lifted_freevars.keys() - l_shared_freevars
        unique_r_freevars = r_lifted_freevars.keys() - r_shared_freevars

        def _sort_by_name(vars):
            return sorted(vars, key=lambda var: var.node.name)

        return (
            list(_sort_by_name(list(l_shared_freevars))),
            list(_sort_by_name(list(r_shared_freevars))),
            list(_sort_by_name(list(unique_l_freevars))),
            list(_sort_by_name(list(unique_r_freevars))),
        )

    (l_shared, r_shared, unique_l, unique_r) = dedup_and_sort_lifted_freevars(
        l_lifted_freevars, r_lifted_freevars
    )

    # Let's say we capture cond(pred, true_fn, false_fn, (x,))
    # with set_subgraph_inputs set to "automatic":
    # true_fn has lifted variables x, a, b, c
    # false_fn has lifted variables x, a, b, d
    # Then fixup_branch_inps makes sure both branches have the same signature, i.e.:
    # - true_fn(x, a, b, c_true_branch, d_false_branch)
    # - false_fn(x, a, b, c_true_branch, d_false_branch)
    #
    # More formally, the signature has three parts in the following order:
    # 1. used in both branches: x, a, b
    # 2. only used in true branches: c, suffixed with _true_branch
    # 3. only used in false branches: d, suffixed with _false_branch
    # Within each part, we re-order the nodes by name to have a deterministic ordering for testing.
    def fixup_branch_inps(graph, lifted_freevars, shared, unique_l, unique_r):
        def _insert_or_replace_phs(new_args, name_suffix):
            for arg in new_args:
                new_ph = graph.placeholder(arg.node.name + name_suffix)
                # Override with new_ph if there exists an old placeholder.
                if arg in lifted_freevars:
                    old_ph = lifted_freevars[arg].node
                    old_ph.replace_all_uses_with(new_ph)
                    # replace_all_uses_with doesn't clean users. Clean it manually so that we can erase it.
                    old_ph.users = {}
                    graph.erase_node(old_ph)

        first_not_ph_node = next(
            node for node in graph.nodes if node.op != "placeholder"
        )
        with graph.inserting_before(first_not_ph_node):
            _insert_or_replace_phs(shared, "")
            _insert_or_replace_phs(unique_l, "_" + l_name)
            _insert_or_replace_phs(unique_r, "_" + r_name)

    fixup_branch_inps(l_graph, l_lifted_freevars, l_shared, unique_l, unique_r)
    fixup_branch_inps(r_graph, r_lifted_freevars, r_shared, unique_l, unique_r)
    return l_graph, r_graph, l_shared, r_shared, unique_l, unique_r


# See NOTE [HigherOrderOperator tracing design] for details of the design
def speculate_subgraph(
    tx,
    f,
    sub_args,
    sub_kwargs,
    description,
    *,
    # source_target is the .value of HigherOrderOpVariable and is the
    # target of the proxy that we created for the higherOrderOperator.
    source_target=None,
    always_restore=False,
    enable_grad=None,
    # NOTE [argument `set_subgraph_inputs`]
    # set_subgraph_inputs controls how to construct subgraphs' placeholders from sub_args.
    # 1. if your HOP supports arbitrary inputs, use set_subgraph_inputs="automatic" (most recommended).
    # 2. if your HOP supports only Tensor and symnode inputs, use set_subgraph_inputs="flatten_manual" (recommended).
    # If sub_args contain Pytree structure (e.g. dict/list/tuple/set), the sub_args will be flattened first.
    # Then the flattened args are manually set as subgraph's placeholders.
    # 3. if your HOP must preserve inputs that are not tensor or symnode as placeholders e.g. AutogradFunctionContextVariable
    # use set_subgraph_inputs="manual" (not recommended). We do not recommend it in general because it has the
    # restriction that users need to manually control how to create placeholders and VariableTrackers for the args.
    set_subgraph_inputs="automatic",
    restore_side_effects=True,
    should_flatten_outputs=False,
    # Pass in an originating tracer - this is needed for preserving context
    # across fwd-bwd for autograd.Function
    tracer=None,
):
    if sub_kwargs is None:
        sub_kwargs = {}

    assert set_subgraph_inputs in {
        "automatic",
        "semi_automatic",
        "flatten_manual",
        "manual",
    }, "Please use one of the supported set_subgraph_inputs options."

    # See NOTE [Temporary argument `set_subgraph_inputs`]
    if sub_kwargs and set_subgraph_inputs != "automatic":
        unimplemented("Use `set_subgraph_inputs=automatic` when passing `sub_kwargs`.")

    try:
        # ensure guards on args get installed in parent subgraph
        f, sub_args, sub_kwargs = LazyVariableTracker.realize_all(
            (f, sub_args, sub_kwargs),
        )

        with tx.output.subtracer(source_target, tracer) as subtracer:
            args = validate_args_and_maybe_create_graph_inputs(
                sub_args, subtracer, tx, set_subgraph_inputs, description
            )

            validate_args_and_maybe_create_graph_inputs(
                sub_kwargs.values(),
                subtracer,
                tx,
                set_subgraph_inputs="automatic",
                description=description,
            )

            autograd_ctx = (
                dynamo_enable_grad(tx, enable_grad)
                if enable_grad is not None
                else contextlib.nullcontext()
            )

            # For handling side effects, we can make an argument that we don't
            # have to do anything here. The side effects infra does a good job
            # of graph breaking if we mutate any nonlocal or global variable
            # while subtracing. As a result, if tracing succeeds, the side effects
            # data structure will only contain read-only data structures that
            # are put there for tracking purposes.
            # But on the other hand, there is an argument that if we ever write
            # a new side effect in Dynamo which does not go through the side
            # effect infra, we can end up in a bad state.
            # Therefore we restore the side effects after tracing. The catch is
            # that we have to specially handle tensor variables. If we have seen a
            # nonlocal variable tensor during subtracing, we want to keep
            # track of that tensor, so that later subtracing or the root tracer
            # itself does not create a new proxy for the already observed tensor
            # variable.
            if restore_side_effects:
                prev_side_effects = tx.output.side_effects.clone()

            with autograd_ctx:
                output = f.call_function(tx, args, sub_kwargs)

            if restore_side_effects:
                new_side_effects = tx.output.side_effects.clone()
                prev_side_effects.track_tensor_variables_from_runahead_side_effects(
                    new_side_effects
                )
                tx.output.side_effects = prev_side_effects

            treespec = None
            if should_flatten_outputs:
                # Flatten the speculated subgraph output.
                output, treespec = _make_inlined(tx, pytree.tree_flatten)(
                    output
                ).unpack_var_sequence(tx)
                # Actually, transform the list (returned by flatten) into a tuple
                # for dynamo consistency.
                output = BuiltinVariable(tuple).call_function(tx, [output], {})

            # Register output to graph
            # Modeled off of compile_and_call_fx_graph
            # TODO: support pytree output
            # We check always_restore because we don't use the output or side effects of always_restore code,
            # like bwd.
            if always_restore:
                # Nothing left to do here
                return (output, treespec), tx.output.graph, subtracer.lifted_freevars
            else:
                from . import TensorVariable

                if not only_consist_of(output, TensorVariable, allow_none=True):
                    unimplemented(
                        "HigherOrderOperator body's output must consist of tensors only"
                    )

                # The output proxies might not belong to this SubgraphTracer
                # (if they are free variables that were never lifted)
                # so lift them here.
                output_proxies = output.as_proxy()
                output_proxies = pytree.tree_map(
                    subtracer.maybe_lift_tracked_freevar_to_input, output_proxies
                )

                tx.output.create_node(
                    "output",
                    "output",
                    (subtracer.create_arg((output_proxies,))),
                    {},
                )
                graph = tx.output.graph
                graph.lint()
                lifted_freevars = subtracer.lifted_freevars

                return (
                    (output, treespec),
                    graph,
                    lifted_freevars,
                )

    except Unsupported as ex:
        f_name = f"{type(f).__name__}"
        if isinstance(f, UserFunctionVariable):
            f_name = f.get_name()
        msg = (
            f"speculate_subgraph: while introspecting {description}, we were unable "
            f"to trace function `{f_name}` into a single graph. This means "
            f"that Dynamo was unable to prove safety for this API and will "
            f"fall back to eager-mode PyTorch, which could lead to a slowdown."
        )
        log.info(msg)
        log.info(ex)
        raise ex


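# A hedged usage sketch of speculate_subgraph (it mirrors the call sites
# further down, e.g. CondHigherOrderVariable; the names are illustrative only):
#
#     (
#         (ret_val, ret_treespec),
#         ret_graph,
#         ret_lifted_freevars,
#     ) = speculate_subgraph(
#         tx,
#         fn_vt,          # VariableTracker for the subgraph body
#         operands,       # list of VariableTrackers for its args
#         {},
#         "cond",
#         source_target=self.value,
#         should_flatten_outputs=True,
#     )
#
# The returned graph is then wrapped into a GraphModule via add_subgraph and
# referenced with make_attr when emitting the HigherOrderOperator call.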


def make_attr(tx, name):
    node = tx.output.create_proxy(
        "get_attr",
        name,
        (),
        {},
    )
    return node


def add_subgraph(tx, name, gm):
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{name}_{i}"
        if candidate in tx.output.nn_modules:
            i += 1
        else:
            next_name = candidate

    gm.__name__ = next_name
    gm.torchdynamo_force_dynamic = False
    # This graph module is not present in the user space, so it can't be
    # accessed by a source. Set source=None.
    tx.output.register_attr_or_module(gm, next_name, source=None)
    return next_name


class TorchHigherOrderOperatorVariable(VariableTracker):
    def __init__(
        self, value: HigherOrderOperator, source: Optional[Source] = None, **kwargs
    ):
        super().__init__(**kwargs)
        self.value = value
        self.source = source

    @staticmethod
    def make(value, source=None, **kwargs):
        if value.__name__ == "cond":
            return CondHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "while_loop":
            return WhileLoopHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ in ("map", "map_impl"):
            return MapHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "executorch_call_delegate":
            return ExecutorchCallDelegateHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "out_dtype":
            return OutDtypeHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "wrap":
            return WrapHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "flex_attention":
            return TemplatedAttentionHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ in (
            "wrap_activation_checkpoint",
            "tag_activation_checkpoint",
        ):
            return CheckpointHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "_export_tracepoint":
            return ExportTracepointHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "trace_wrapped":
            return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs)
        elif value.__name__ == "strict_mode":
            return StrictModeHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "associative_scan":
            return AssociativeScanHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "call_torchbind":
            return CallTorchbindHigherOrderVariable(value, source, **kwargs)
        else:
            unimplemented(f"HigherOrderOperator {value.__name__}")

    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ) -> VariableTracker:
        unimplemented(f"HigherOrderOperator {self.value.__name__}")


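# For illustration (a hedged sketch, not executed anywhere in this file):
# `TorchHigherOrderOperatorVariable.make(torch.ops.higher_order.cond, source)`
# would return a `CondHigherOrderVariable`, whose `call_function` below performs
# the actual speculation of both branches and emits a single
# torch.ops.higher_order.cond node into the parent graph.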


class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
    @raise_hard_error_if_graph_break(
        reason="Cond doesn't work unless it is captured completely with torch.compile."
    )
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from . import ListVariable, TensorVariable

        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

        for i, k in enumerate(["pred", "true_fn", "false_fn", "operands"]):
            if v := kwargs.pop(k, None):
                assert i == len(
                    args
                ), "did not provide the right number of non-keyword args"
                args.append(v)

        if kwargs:
            unimplemented(f"torch.cond: Got unexpected kwargs: {list(kwargs.keys())}")

        # TODO(voz): Support fake tensor dispatch for recursive
        # ops - see torch/dispatch/_dispatcher.py
        if len(args) != 4:
            unimplemented(
                f"Expected 4 arguments but got {len(args)}.\n"
                f"Usage: cond(pred, true_fn, false_fn, operands)",
            )
        # predicate
        if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
            unimplemented(
                f"Expected pred to be bool or a boolean tensor with single "
                f"item but got {str(type(args[0]))} "
                f"with original python type {str(args[0].python_type())}.",
            )

        # operands
        if not isinstance(args[3], (ListVariable, TupleVariable)):
            unimplemented(
                f"Expected a tuple but got {args[3].python_type()}",
            )
        operands = args[3].unpack_var_sequence(tx)
        if not only_consist_of(args[3], (TensorVariable,)):
            unimplemented(
                "Expect operands to be a tuple of pytrees that only consists of tensor leaves."
            )

        # branches
        _check_supported_callable_arg(tx, args[1], "true_fn")
        _check_supported_callable_arg(tx, args[2], "false_fn")

        # Our strategy for tracing the true/false branches of cond
        # is to checkpoint our graphstate, run the true branch,
        # roll it back to the checkpoint, and run the false
        # branch, and then merge the graphstates. Well, perhaps
        # "merge" is too strong a word: we mostly assert that
        # the resulting graphstates have to be the same.
        #
        # We only permit guards to diverge (we union the guards from
        # both branches). In particular, this means that side
        # effects are NOT permitted inside true/false branches; this
        # would be difficult to implement, because of the path
        # explosion problem.

        def speculate_branch(branch):
            # NB: 0 is predicate
            ix = 1 if branch else 2
            # TODO: Support kwargs
            (
                (ret_val, ret_treespec),
                ret_graph,
                ret_lifted_freevars,
            ) = speculate_subgraph(
                tx,
                args[ix],
                operands,
                {},
                "cond",
                source_target=self.value,
                should_flatten_outputs=True,
            )

            if not only_consist_of(ret_val, (TensorVariable,)):
                unimplemented(
                    "Expected branches to return a possibly nested list/tuple/dict of tensors but it consists of non-tensors.",
                )
            return ret_val, ret_treespec, ret_graph, ret_lifted_freevars

        (true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch(
            True
        )
        true_nn_modules = dict(tx.output.nn_modules)

        (
            false_r,
            false_treespec,
            false_graph,
            false_lifted_freevars,
        ) = speculate_branch(False)
        false_nn_modules = dict(tx.output.nn_modules)

        same_treespec = _make_inlined(tx, pytree.TreeSpec.__eq__)(
            true_treespec, false_treespec
        )
        if not same_treespec.as_python_constant():
            unimplemented("Expected branches to return the same pytree structure.")

        def diff_meta(tensor_vars1, tensor_vars2):
            assert all(
                isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2
            )
            all_diffs = []
            for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
                # We check the metadata associated with meta["example_value"]
                meta1 = _extract_tensor_metadata(
                    var1.proxy.node.meta["example_value"], include_contiguity=False
                )
                meta2 = _extract_tensor_metadata(
                    var2.proxy.node.meta["example_value"], include_contiguity=False
                )
                if meta1 != meta2:
                    all_diffs.append((f"pair{i}:", meta1, meta2))
            return all_diffs

        if diffs := diff_meta(
            true_r.unpack_var_sequence(tx), false_r.unpack_var_sequence(tx)
        ):
            unimplemented(
                f"Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:{diffs}"
            )

        (
            true_graph,
            false_graph,
            true_shared,
            false_shared,
            unique_true,
            unique_false,
        ) = _merge_graph_inputs(
            true_graph,
            true_lifted_freevars,
            "true_branch",
            false_graph,
            false_lifted_freevars,
            "false_branch",
        )

        true_name = add_subgraph(
            tx,
            "cond_true",
            torch.fx.GraphModule(true_nn_modules, true_graph),
        )
        false_name = add_subgraph(
            tx,
            "cond_false",
            torch.fx.GraphModule(false_nn_modules, false_graph),
        )

        true_node = make_attr(tx, true_name)
        false_node = make_attr(tx, false_name)

        p_args = (
            args[0].as_proxy(),
            true_node,
            false_node,
            # We pick true_shared but it shouldn't matter
            true_shared + unique_true + unique_false,
        )

        flat_example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            true_r.as_proxy(),
        )

        return _call_function_and_unflatten_output(
            tx,
            torch.ops.higher_order.cond,
            p_args,
            {},
            flat_example_value,
            true_treespec,
        )


class CallTorchbindHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def __init__(self, hop, source, script_obj_var, method_name):
        super().__init__(hop, source)
        self.script_obj_var = script_obj_var
        self.method_name = method_name

    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ) -> VariableTracker:
        from .builder import wrap_fx_proxy

        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

        args_proxy = [arg.as_proxy() for arg in args]
        kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()}
        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=tuple(
                    [self.script_obj_var.as_proxy(), self.method_name] + args_proxy
                ),
                kwargs=kwargs_proxy,
            ),
        )


class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
    @raise_hard_error_if_graph_break(
        reason="while_loop doesn't work unless it is captured completely with torch.compile."
    )
    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ) -> VariableTracker:
        from . import TensorVariable

        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

        for i, k in enumerate(["cond_fn", "body_fn", "operands"]):
            if v := kwargs.pop(k, None):
                assert i == len(
                    args
                ), "did not provide the right number of non-keyword args"
                args.append(v)

        if kwargs:
            unimplemented(
                f"torch.while_loop: Got unexpected kwargs: {list(kwargs.keys())}"
            )

        if len(args) != 4:
            unimplemented(
                f"Expected 4 arguments but got {len(args)}.\n"
                f"Usage: while_loop(cond_fn, body_fn, operands)",
            )

        _check_supported_callable_arg(tx, args[0], "cond_fn")
        _check_supported_callable_arg(tx, args[1], "body_fn")

        # operands
        if not isinstance(args[2], (ListVariable, TupleVariable)):
            unimplemented(
                f"Expected a tuple but got {args[2].python_type()}",
            )
        operands = args[2].unpack_var_sequence(tx)
        if not only_consist_of(args[2], (TensorVariable,)):
            unimplemented(
                "Expect operands to be a tuple of pytrees that only consists of tensor leaves."
            )

        # additional inputs check
        if not isinstance(args[3], (ListVariable, TupleVariable)):
            unimplemented(
                f"Expected a tuple but got {args[3].python_type()}",
            )
        additional_inputs = args[3].unpack_var_sequence(tx)

        (
            (cond_r, cond_treespec),
            cond_graph,
            cond_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            args[0],
            operands + additional_inputs,
            {},
            "while_loop",
            source_target=self.value,
            set_subgraph_inputs="manual",
        )
        cond_nn_modules = dict(tx.output.nn_modules)
        if not isinstance(cond_r, TensorVariable):
            unimplemented(
                f"Expected cond_fn to return a tensor but got {cond_r.python_type()}",
            )

        cond_r_meta = _extract_tensor_metadata(
            cond_r.proxy.node.meta["example_value"], include_contiguity=False
        )
        if not cond_r_meta.dtype == torch.bool or not cond_r_meta.shape == torch.Size(
            []
        ):
            unimplemented(
                f"Expected cond_fn to return a tensor with shape () but got {cond_r_meta.shape}"
            )

        (
            (body_r, body_treespec),
            body_graph,
            body_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            args[1],
            operands + additional_inputs,
            {},
            "while_loop",
            source_target=self.value,
            set_subgraph_inputs="manual",
            should_flatten_outputs=True,
        )
        (
            cond_graph,
            body_graph,
            cond_shared,
            body_shared,
            cond_unique,
            body_unique,
        ) = _merge_graph_inputs(
            cond_graph,
            cond_lifted_freevars,
            "cond_fn",
            body_graph,
            body_lifted_freevars,
            "body_fn",
        )

        # Note: cond_shared and body_shared refer to the same proxy in the parent graph,
        # so using either of them is OK. Use cond_shared as it doesn't matter.
        additional_lifted_inputs = cond_shared + cond_unique + body_unique

        body_nn_modules = dict(tx.output.nn_modules)

        cond_name = add_subgraph(
            tx,
            "cond_fn",
            torch.fx.GraphModule(cond_nn_modules, cond_graph),
        )
        body_name = add_subgraph(
            tx,
            "body_fn",
            torch.fx.GraphModule(body_nn_modules, body_graph),
        )

        cond_node = make_attr(tx, cond_name)
        body_node = make_attr(tx, body_name)

        p_args = (
            cond_node,
            body_node,
            tuple([operand.as_proxy() for operand in operands]),
            tuple(
                [inp.as_proxy() for inp in additional_inputs] + additional_lifted_inputs
            ),
        )

        flat_example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            body_r.as_proxy(),
        )

        return _call_function_and_unflatten_output(
            tx,
            torch.ops.higher_order.while_loop,
            p_args,
            {},
            flat_example_value,
            body_treespec,
        )


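# A hedged sketch of the calling convention handled above (illustrative only;
# the function names below are not part of this module):
#
#     def cond_fn(it, x):
#         return it < 10            # must be a 0-dim bool tensor
#
#     def body_fn(it, x):
#         return it + 1, x.sin()    # returns new values for the carried operands
#
#     while_loop(cond_fn, body_fn, (it, x))
#
# Dynamo speculates cond_fn and body_fn into `cond_fn`/`body_fn` submodules and
# emits a single torch.ops.higher_order.while_loop node in the parent graph.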


class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
    @raise_hard_error_if_graph_break(
        reason="associative_scan must be captured completely with torch.compile."
    )
    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ) -> VariableTracker:
        from .builder import SourcelessBuilder, wrap_fx_proxy

        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))

        def arg_extractor(combine_fn, input, dim):
            return combine_fn, input, dim

        combine_fn, input, dim = arg_extractor(*args, **kwargs)

        if input.python_type() != list:
            unimplemented(
                f"Expected input to be a list of tensors but got {input.python_type()}",
            )
        assert isinstance(input, torch._dynamo.variables.lists.BaseListVariable)

        # Trace the subgraph
        # TODO: Fix these pointless new_empty calls appearing in the dynamo output graph.
        null_shape = SourcelessBuilder.create(tx, ())
        sub_args = [
            leaf.call_method(tx, "new_empty", args=(null_shape,), kwargs={})
            for leaf in itertools.chain(input.items, input.items)
        ]
        (
            (combine_result, combine_treespec),
            combine_graph,
            combine_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            combine_fn,
            sub_args,
            sub_kwargs={},
            description="scan_combine",
            source_target=self.value,
            set_subgraph_inputs="flatten_manual",
        )

        if combine_lifted_freevars:
            unimplemented(
                f"Combine fn had unexpected freevars: {combine_lifted_freevars}"
            )

        if combine_result.python_type() != list:
            unimplemented(
                f"Expected combine_fn to return a list of tensors but got {combine_result.python_type()}",
            )

        input_proxy = input.as_proxy()
        combine_result_proxy = combine_result.as_proxy()
        for result, inp_proxy in zip(combine_result_proxy, input_proxy):
            inp_meta = inp_proxy.node.meta["example_value"]
            combine_result_meta = result.node.meta["example_value"]
            if combine_result_meta.device != inp_meta.device:
                unimplemented(
                    f"Expected combine_fn to return a tensor on device {inp_meta.device} but "
                    + f"got {combine_result_meta.device}"
                )
            if combine_result_meta.dtype != inp_meta.dtype:
                unimplemented(
                    f"Expected combine_fn to return a tensor of {inp_meta.dtype} but "
                    + f"got {combine_result_meta.dtype}"
                )

            if combine_result_meta.shape != ():
                unimplemented(
                    f"Expected combine_fn to return a tensor with shape () but got {combine_result_meta.shape}"
                )

        combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
        combine_fn_name = add_subgraph(tx, "scan_combine", combine_gm)

        p_args = (
            make_attr(tx, combine_fn_name),
            input_proxy,
            dim.as_proxy(),
        )

        with tx.fake_mode:
            out_meta = tuple(
                inp_proxy.node.meta["example_value"].clone()
                for inp_proxy in input_proxy
            )
        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function", torch.ops.higher_order.associative_scan, p_args, {}
            ),
            example_value=out_meta,
        )


def non_single_tensor_return_unsupported(api, ret):
    from . import TensorVariable

    if not isinstance(ret, TensorVariable):
        raise Unsupported(
            f"{api} over function that returns something " f"other than one Tensor"
        )


class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ) -> VariableTracker:
        from . import TensorVariable
        from .builder import wrap_fx_proxy_cls

        if len(kwargs) > 0:
            unimplemented(
                "torch.ops.higher_order.map: kwargs are not supported in the map operator."
            )

        _check_supported_callable_arg(tx, args[0].realize(), "map_fn")

        assert type(args[1].realize()) is TensorVariable

        sample_shape = get_fake_value(args[1].as_proxy().node, tx).size()

        if len(sample_shape) < 1 or sample_shape[0] == 0:
            unimplemented(
                "map() operator doesn't support scalar or zero-sized tensors during tracing."
            )

        # To get the example output from map() we will need to provide at least one sample to
        # the loop body. In our case we will always use xs[0], and our map() won't support
        # zero-sized tensors during tracing.
        first_dim = wrap_fx_proxy_cls(
            target_cls=TensorVariable, tx=tx, proxy=args[1].as_proxy()[0]
        )

        # TODO: Support kwargs
        (
            (body_r, body_spec),
            body_graph,
            body_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            args[0],
            [
                first_dim,
                *args[2:],
            ],
            {},
            "torch.ops.higher_order.map",
            source_target=self.value,
            set_subgraph_inputs="flatten_manual",
            should_flatten_outputs=True,
        )

        subgraph_example_value = [
            proxy.node.meta["example_value"] for proxy in body_r.as_proxy()
        ]

        with tx.output.fake_mode:
            # We need to expand the example output from map() so that it has
            # the same first dimension as the mapped input.
            # We also do a clone with contiguous_format. This is to be consistent with
            # the eager semantics of map, which stacks the outputs. The result is contiguous
            # as a result of the stack operation.
            map_example_out = [
                t.expand(sample_shape[0], *t.size()).clone(
                    memory_format=torch.contiguous_format
                )
                for t in subgraph_example_value
            ]

        body_nn_modules = dict(tx.output.nn_modules)

        body_name = add_subgraph(
            tx,
            "map_body",
            torch.fx.GraphModule(body_nn_modules, body_graph),
        )

        body_node = make_attr(tx, body_name)

        p_args = (
            body_node,
            [args[1].as_proxy()],
            [arg.as_proxy() for arg in args[2:]] + list(body_lifted_freevars.keys()),
        )

        return _call_function_and_unflatten_output(
            tx, torch.ops.higher_order.map_impl, p_args, {}, map_example_out, body_spec
        )


class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        # This is the operator for delegation within Executorch which calls a
        # specific function in the given lowered module with the given
        # operators. The actual operator is defined in the Executorch codebase.
        # This is a bad hierarchical violation since
        # executorch_call_delegate sits at a higher level than dynamo, but
        # there's no real solution to this issue yet.
        if len(kwargs) > 0:
            unimplemented(
                "executorch_call_delegate: kwargs arguments were not enabled."
            )
        lowered_module = tx.output.get_submodule(args[0].module_key)

        lowered_node = make_attr(tx, args[0].module_key)

        p_args = tuple(arg.as_proxy() for arg in args[1:])
        real_sub_args = pytree.tree_map_only(
            torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args
        )

        example_value = lowered_module.original_module.module()(*real_sub_args)

        # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]:
        # executorch modules promise not to alias inputs and outputs.
        # Thus, output FakeTensors will correctly not alias input FakeTensors.
        _assert_tensors_nonaliasing(real_sub_args, example_value)

        p_args = (lowered_node,) + p_args

        # Store the invocation as a call
        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=tuple(p_args),
                kwargs={},
            ),
            example_value=example_value,
        )


class FunctorchHigherOrderVariable(UserFunctionVariable):
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if not torch._dynamo.config.capture_func_transforms:
            name = self.get_name()
            fn = {
                "grad_impl": "grad",
                "vmap_impl": "vmap",
                "vjp": "vjp",
                "jvp": "jvp",
                "jacrev": "jacrev",
                "jacfwd": "jacfwd",
                "hessian": "hessian",
                "linearize": "linearize",
            }.get(name)
            assert name is not None
            unimplemented(
                f"torch.func.{fn} capture is disabled, "
                "it can be turned on by setting "
                "`torch._dynamo.config.capture_func_transforms=True`"
            )
        return super().call_function(tx, args, kwargs)


class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def create_wrapped_node(self, tx, args, kwargs, description):
        # See NOTE [HigherOrderOperator tracing design] for more details

        (
            (body_r, treespec),
            body_graph,
            body_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            args[0],  # function
            [*args[1:]],
            kwargs,
            description,
            source_target=self.value,
            should_flatten_outputs=True,
        )

        body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
        body_name = add_subgraph(
            tx,
            "wrap_body",
            body_gmod,
        )

        body_node = make_attr(tx, body_name)

        # Since we call `speculate_subgraph` with `set_subgraph_inputs="automatic"`,
        # all the arguments are lifted.
        lifted_args = tuple(arg for arg in body_lifted_freevars.keys())

        proxy_args = (body_node,) + lifted_args
        example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            body_r.as_proxy(),
        )

        return proxy_args, {}, example_value, body_r, treespec, body_gmod

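    # For context (a hedged illustration): the proxy args produced above look
    # like `(wrap_body_0, *lifted_args)`, and call_function below passes them
    # straight to the wrap HigherOrderOperator together with the flattened
    # example value and treespec.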

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # This flattens the kwargs into lifted args
        p_args, p_kwargs, example_value, body_r, treespec, _ = self.create_wrapped_node(
            tx, args, kwargs, "wrap"
        )

        if len(p_kwargs) > 0:
            unimplemented("kwargs should have been flattened into lifted args")

        flat_example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            body_r.as_proxy(),
        )

        return _call_function_and_unflatten_output(
            tx, self.value, tuple(p_args), p_kwargs, flat_example_value, treespec
        )


class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        if len(kwargs) > 0:
            unimplemented("out_dtype does not handle kwargs")

        p_args = tuple(arg.as_proxy() for arg in args)
        op = p_args[0]
        output_dtype = p_args[1]
        fake_sub_args = pytree.tree_map_only(
            torch.fx.Proxy, lambda a: a.node.meta["example_value"], p_args[2:]
        )
        # This is a simplified implementation of this operator just for tracing.
        # The actual implementation may also first promote the arguments.
        example_value = op(*fake_sub_args).to(dtype=output_dtype)

        # Store the invocation as a call
        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=tuple(p_args),
                kwargs={},
            ),
            example_value=example_value,
        )


class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable):
    @raise_hard_error_if_graph_break(
        reason="strict_mode HOO doesn't work unless it is captured completely with torch.compile."
    )
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        callable = args[0]

        unpacked_sequence = args[1].unpack_var_sequence(tx)
        # TODO (tmanlaibaatar) support pytree here
        for arg in unpacked_sequence:
            if isinstance(arg, (ListVariable, TupleVariable, ConstDictVariable)):
                unimplemented("strict_mode HOO only works for flat inputs for now")

        if kwargs:
            unimplemented(
                f"strict_mode HOO received unexpected kwargs: {list(kwargs.keys())}"
            )

        (
            (ret_val, ret_treespec),
            ret_graph,
            ret_lifted_freevars,
        ) = speculate_subgraph(
            tx,
            args[0],
            unpacked_sequence,
            {},
            "strict_mode",
            source_target=self.value,
            should_flatten_outputs=True,
        )

        strict_mode_nn_modules = dict(tx.output.nn_modules)

        strict_mode_name = add_subgraph(
            tx,
            "strict_mode_body",
            torch.fx.GraphModule(strict_mode_nn_modules, ret_graph),
        )

        strict_mode_node = make_attr(tx, strict_mode_name)
        p_args = (
            strict_mode_node,
            tuple(arg for arg in ret_lifted_freevars.keys()),
        )

        flat_example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            ret_val.as_proxy(),
        )

        return _call_function_and_unflatten_output(
            tx,
            torch.ops.higher_order.strict_mode,
            p_args,
            {},
            flat_example_value,
            ret_treespec,
        )


class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ) -> VariableTracker:
        from torch._higher_order_ops.wrap import TagActivationCheckpoint
        from torch.utils.checkpoint import noop_context_fn
        from .builder import wrap_fx_proxy

        context_fn = None
        if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn:
            ctx = kwargs.pop("context_fn")
            if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable):
                context_fn = ctx.fn
            elif isinstance(
                ctx, torch._dynamo.variables.functions.FunctoolsPartialVariable
            ):
                context_fn = ctx.as_python_constant()
            else:
                raise NotImplementedError(
                    f"checkpoint not implemented for {type(ctx)} context_fn"
                )

        checkpoint_kwargs, gmod_kwargs = TagActivationCheckpoint.divide_kwargs(kwargs)

        # Here we use checkpoint_kwargs (and not gmod_kwargs). gmod_kwargs are
        # already flattened above and managed inside the fx graph.
        (
            p_args,
            _,
            example_value,
            body_r,
            treespec,
            checkpointed_gmod,
        ) = self.create_wrapped_node(
            tx, args, gmod_kwargs, "torch.utils.checkpoint.checkpoint"
        )
        if context_fn is not None:
            checkpointed_gmod.meta["_checkpoint_context_fn"] = context_fn

        _, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs)

        # Store the invocation as a call
        variable = wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=tuple(p_args),
                kwargs=checkpoint_kwargs,
            ),
            example_value=example_value,
        )

        if treespec is None:
            return variable

        # Transform variable back into a list (previously made into a tuple by
        # speculate_subgraph function) so as to respect the pytree API typing.
        variable = BuiltinVariable(list).call_function(tx, [variable], {})

        return _make_inlined(tx, pytree.tree_unflatten)(variable, treespec)


class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        p_args = tuple(arg.as_proxy() for arg in args)
        p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=p_args,
                kwargs=p_kwargs,
            ),
            example_value=None,
        )


class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable):
    """
    Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace
    by unwrapping the higher order op and inlining through it. This op
    is created by dynamo to survive through AotAutograd, then unwrapped
    here in the call to dynamo from compiled autograd.
    """

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        kwargs = dict(kwargs)
        fn = kwargs.pop("fn")
        return fn.call_function(tx, args, kwargs)


class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
    @staticmethod
    def normalize_to_args(args, kwargs):
        # input signature is (query, key, value, score_mod, *other_buffers)
        # Flatten args and kwargs into lists
        flat_args = pytree.tree_flatten(args)[0]
        flat_kwargs = pytree.tree_flatten(kwargs)[0]

        # Combine the flattened lists
        all_args = flat_args + flat_kwargs
        return all_args

    def create_wrapped_node(
        self, tx, query: "VariableTracker", score_function: "VariableTracker"
    ):
        from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
        from .builder import SourcelessBuilder

        tx: InstructionTranslator = tx

        scores_require_grad: bool = query.requires_grad
        score = query.call_method(
            tx,
            "new_empty",
            (SourcelessBuilder.create(tx, []),),
            {"requires_grad": SourcelessBuilder.create(tx, scores_require_grad)},
        )

        def create_scalar():
            return query.call_method(
                tx,
                "new_empty",
                (SourcelessBuilder.create(tx, []),),
                {
                    "dtype": SourcelessBuilder.create(tx, torch.int32),
                },
            )

        bhmn = [create_scalar() for _ in range(4)]
        new_args = [score, *bhmn]

        with TransformGetItemToIndex():
            (
                (body_output, body_treespec),
                body_graph,
                body_lifted_freevars,
            ) = speculate_subgraph(
                tx,
                score_function,
                new_args,
                {},  # expect only args no kwargs for now
                description="flex_attention",
                source_target=self.value,
                set_subgraph_inputs="flatten_manual",
            )

        body_name = add_subgraph(
            tx,
            "flex_attention",
            torch.fx.GraphModule(tx.output.nn_modules, body_graph),
        )

        body_node = make_attr(tx, body_name)

        # It is possible that the score-mod function captures some free variables that are not
        # passed in as arguments. In this case, we need to lift them, which is handled by speculate_subgraph.
        # We then need to create proxies for this + the inputs.

        lifted_args = tuple(arg for arg in body_lifted_freevars.keys())

        proxy_args = (body_node,) + lifted_args

        return proxy_args

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        query, key, value, score_mod = self.normalize_to_args(args, kwargs)

        p_args = self.create_wrapped_node(tx, query, score_mod)
        proxied_args = [query, key, value]

        # Store the invocation as a call
        # Norm_kwargs contains the score_function and we don't want to proxy this because
        # proxying user-defined functions is not supported.
        inp_args, _ = proxy_args_kwargs(proxied_args, {})

        query_meta = query.as_proxy().node.meta["example_value"]
        logsumexp_shape = query_meta.size()[:-1]  # [B, H, M]
        with torch._guards.TracingContext.try_get().fake_mode:
            out_meta = torch.empty_like(
                query_meta, memory_format=torch.contiguous_format
            )
            lse_meta = query_meta.new_empty(logsumexp_shape, dtype=torch.float32)
        example_value = (out_meta, lse_meta)

        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=inp_args + p_args,
                kwargs={},
            ),
            example_value=example_value,
        )


class AutogradFunctionApplyVariable(VariableTracker):
    def __init__(self, fwd_graph, bwd_graph, parent_source, **kwargs):
        super().__init__(**kwargs)
        self.fwd_graph = fwd_graph
        self.bwd_graph = bwd_graph
        self.parent_source = parent_source

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from . import (
            AutogradFunctionContextVariable,
            UserDefinedClassVariable,
            UserFunctionVariable,
            UserMethodVariable,
        )
        from .builder import wrap_fx_proxy

        """
        Consider the following:
        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x.sin()
            @staticmethod
            def backward(ctx, grad):
                x, = ctx.saved_tensors
                return grad * x.cos()
        We want the resulting graphs to look like:
        def fwd(ctx, x):
            # (output, saved tensors / attrs)
            return (x.sin(), [x])
        # bwd(ctx, grad0, grad1, ..., gradn, *saved_tensors_or_attrs)
        def bwd(ctx, grad, x):
            return grad * x.cos()
        To accomplish this, we're going to:
        1. Construct a ctx object
        2. (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph on MySin.forward (manually_set_inputs=True)
        3. (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph on MySin.backward, while manually setting
        the ctx and grad inputs.
        4. Manually rewriting the fwd graph's output to be (output, stuff_that_gets_used in bwd_graph)
        Getting from 3 to 4 is pretty elegant: stuff_that_gets_used in bwd graph is
        just the bwd_freevars returned from speculate_subgraph, assuming MySin.backward
        doesn't capture any arguments.
        All these steps work if MySin.backward doesn't capture any values. This is a
        limitation in general that we should check for.
        """

        prev_side_effects = tx.output.side_effects.clone()
        fwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
            tx.output,
            parent=tx.output.current_tracer,
            source_target="autograd.Function",
        )

        fwd_src = AttrSource(self.parent_source, member="forward")
        ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
        if isinstance(self.fwd_graph, types.FunctionType):
            fwd_fn = UserFunctionVariable(self.fwd_graph)
            fwd_args = [ctx, *args]
        elif isinstance(self.fwd_graph, types.MethodType):
            fwd_fn = UserMethodVariable(
                self.fwd_graph.__func__,
                UserDefinedClassVariable(self.fwd_graph.__class__),
            )
            fwd_args = [fwd_fn.obj, ctx, *args]
        else:
            unimplemented("non-function or method")

        # Speculate subgraph on the fwd
        (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
            tx,
            fwd_fn,
            fwd_args,
            kwargs,
            "autograd.Function",
            set_subgraph_inputs="semi_automatic",
            restore_side_effects=False,
            tracer=fwd_tracer,
        )

        if ctx.mutable_local in tx.output.side_effects.store_attr_mutations:
            if (
                "_materialize_non_diff_grads"
                in tx.output.side_effects.store_attr_mutations[ctx.mutable_local]
            ):
                unimplemented("NYI")

        bwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
            tx.output,
            parent=fwd_tracer,
            source_target="autograd.Function",
        )

        # Speculate subgraph on the backward. We make the
        # bwd tracer a child of the fwd tracer, because backward may rely on
        # tensors/attrs created in the fwd tracer.

        if isinstance(fwd_out, variables.BaseListVariable):
            bwd_args = [ctx, *fwd_out.items]
        else:
            bwd_args = [ctx, fwd_out]

        bwd_src = AttrSource(self.parent_source, member="backward")
        if isinstance(self.bwd_graph, types.FunctionType):
            bwd_fn = UserFunctionVariable(self.bwd_graph, source=bwd_src)
        elif isinstance(self.bwd_graph, types.MethodType):
            bwd_fn = UserMethodVariable(
                self.bwd_graph.__func__,
                UserDefinedClassVariable(self.bwd_graph.__class__),
                source=bwd_src,
            )
            bwd_args = [bwd_fn.obj, *bwd_args]
        else:
            unimplemented("non-function or method")

        def is_strict_for(v: VariableTracker):
            if isinstance(v, variables.TensorVariable):
                # we can be more lax for stuff from forward
                return v.proxy.tracer is not fwd_tracer
            return True

        with tx.output.subtracer(fwd_fn, fwd_tracer), tx.strict_translation_mode(
            is_strict_for
        ):
            (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph(
                tx,
                bwd_fn,
                bwd_args,
                kwargs,
                "autograd.Function",
                enable_grad=False,
                set_subgraph_inputs="manual",
                restore_side_effects=False,
                tracer=bwd_tracer,
            )

        # TODO: assert that bwd_graph didn't capture values that were
        # not created inside fwd_graph.

        # TODO(oulgen): Ideally, we would not do a linear search for the output
        # node, but as things currently are there could be nodes after the
        # output node.
        # This is bug-prone: if there's code after the output node, then
        # graph.output will append the new output at the very end, which
        # might be a behavior difference.

        # Rewrite the output of fwd_graph to (output, stuff_necessary_for_bwd)
        for node in fwd_graph.find_nodes(op="output"):
            fwd_graph.erase_node(node)
            break

        # Because we lift the bwd_freevars as inputs of the bwd_graph,
        # we have to manually add the bwd_freevars as outputs of fwd_graph.
        # However, the bwd_freevars obtained from speculate_subgraph use the Proxies in the bwd_graph,
        # so we need to convert them to Proxies in the fwd_graph and then generate the new fwd_graph output.
        fwd_proxy_of_bwd_freevars = []
        for k in bwd_freevars.keys():
            if k in fwd_freevars:
                fwd_proxy_of_bwd_freevars.append(fwd_freevars[k])
            else:
                fwd_proxy_of_bwd_freevars.append(k)

        new_fwd_graph_outputs = (fwd_out.as_proxy(), fwd_proxy_of_bwd_freevars)
        new_fwd_graph_outputs = pytree.tree_map(lambda x: x.node, new_fwd_graph_outputs)
        fwd_graph.output(new_fwd_graph_outputs)
        fwd_graph.lint()

        # Store fwd_body
        fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
        fwd_name = add_subgraph(
            tx,
            "fwd_body",
            torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph),
        )

        fwd_node = make_attr(tx, fwd_name)

        # The type of the original args can be arbitrary, but we only support basic types in the FX graph.
        # So the speculated subgraph input includes the original tensor args and the lifted freevars.
        # We need to filter out the original tensor args and concat them with the lifted freevars
        # to generate the proxy args for the FX call_function node.
        filtered_args = []
        # A boolean list to mark if the type of the corresponding argument is tensor.
        # This is used to determine if a FX node's argument should be an argument of
        # ApplyTemplate.forward and if we should skip the output from ApplyTemplate.backward
        # at torch._functorch.autograd_function.AutogradFunctionApply.
        args_tensor_mask = [False] * len(args)
        for i, arg in enumerate(args):
            if isinstance(arg, (variables.TensorVariable, variables.SymNodeVariable)):
                filtered_args.append(arg)
                args_tensor_mask[i] = True

        # Rewrite the output of bwd_graph to remove the grad output for the non-Tensor args.
        new_bwd_graph_outputs = None
        for node in bwd_graph.find_nodes(op="output"):
            bwd_graph.erase_node(node)
            break

        # The same as the above fwd proxies: we need to use the bwd proxies in the bwd_graph
        # if some of the outputs are from fwd_freevars.
        bwd_out_proxy = bwd_out.as_proxy()
        bwd_proxy_of_fwd_freevars = []
        if isinstance(bwd_out_proxy, (tuple, list)):
            for k in bwd_out_proxy:
                if k in bwd_freevars:
                    bwd_proxy_of_fwd_freevars.append(bwd_freevars[k])
                else:
                    bwd_proxy_of_fwd_freevars.append(k)
        else:
            if bwd_out_proxy in bwd_freevars:
                bwd_proxy_of_fwd_freevars = bwd_freevars[bwd_out_proxy]
            else:
                bwd_proxy_of_fwd_freevars = bwd_out_proxy

        # Remove bwd output for non-Tensor args.
        output_proxy = bwd_proxy_of_fwd_freevars
        if isinstance(output_proxy, (tuple, list)):
            new_bwd_graph_outputs = ()
            for x, mask in zip(output_proxy, args_tensor_mask):
                if mask:
                    new_bwd_graph_outputs = new_bwd_graph_outputs + (x,)
                else:
                    assert x is None, f"Grad of non-Tensor arg {x} is not None."
        else:
            new_bwd_graph_outputs = output_proxy

        # Update the bwd graph output.
        new_bwd_graph_outputs = pytree.tree_map(
            lambda x: None if x is None else x.node, new_bwd_graph_outputs
        )
        bwd_graph.output(new_bwd_graph_outputs)
        bwd_graph.lint()

        # Store bwd_body
        bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
        bwd_name = add_subgraph(
            tx,
            "bwd_body",
            torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph),
        )

        bwd_node = make_attr(tx, bwd_name)

        tx.output.side_effects = prev_side_effects

        p_args = (
            fwd_node,
            bwd_node,
            *([arg.as_proxy() for arg in filtered_args] + list(fwd_freevars.keys())),
        )
        example_value = pytree.tree_map_only(
            torch.fx.Proxy,
            lambda a: a.node.meta["example_value"],
            fwd_out.as_proxy(),
        )

        # Store the invocation as a call
        from torch._functorch.autograd_function import autograd_function_apply

        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                autograd_function_apply,
                args=p_args,
                kwargs={"args_tensor_mask": args_tensor_mask},
            ),
            example_value=example_value,
        )