xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/higher_order_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import contextlib
4import functools
5import itertools
6import logging
7import types
8
9from typing import Dict, List, Optional, TYPE_CHECKING
10
11import torch._C
12import torch.fx
13import torch.nn
14import torch.onnx.operators
15from torch._dynamo.utils import get_fake_value
16from torch._dynamo.variables import ConstantVariable
17from torch._dynamo.variables.base import VariableTracker
18from torch._dynamo.variables.builtin import BuiltinVariable
19from torch._dynamo.variables.functions import UserFunctionVariable
20from torch._dynamo.variables.tensor import SymNodeVariable
21from torch._guards import Source
22from torch._ops import HigherOrderOperator
23from torch.fx.passes.shape_prop import _extract_tensor_metadata
24from torch.utils import _pytree as pytree
25from .. import variables
26
27from ..exc import UncapturedHigherOrderOpError, unimplemented, Unsupported
28from ..source import AttrSource
29from ..utils import proxy_args_kwargs
30from .dicts import ConstDictVariable
31from .lazy import LazyVariableTracker
32from .lists import ListVariable, TupleVariable
33
34if TYPE_CHECKING:
35    from torch._dynamo.symbolic_convert import InstructionTranslator
36
37
38log = logging.getLogger(__name__)
39
40
41def raise_hard_error_if_graph_break(reason):
42    def deco(fn):
43        @functools.wraps(fn)
44        def graph_break_as_hard_error(*args, **kwargs):
45            try:
46                return fn(*args, **kwargs)
47            except Unsupported as e:
48                msg = " Scroll up to find out what causes the graph break."
49                raise UncapturedHigherOrderOpError(reason + msg) from e
50
51        return graph_break_as_hard_error
52
53    return deco
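# Illustrative usage sketch: the decorator wraps a HOP's call_function so that any
# graph break raised while tracing its body surfaces as a hard
# UncapturedHigherOrderOpError instead of silently falling back to eager, e.g.:
#
#     @raise_hard_error_if_graph_break(
#         reason="Cond doesn't work unless it is captured completely with torch.compile."
#     )
#     def call_function(self, tx, args, kwargs):
#         ...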
54
55
56@contextlib.contextmanager
57def dynamo_enable_grad(tx, enable=True):
58    from . import GradModeVariable
59
60    org_value = torch.is_grad_enabled()
61    try:
62        GradModeVariable.create(tx, enable, initialized=True)
63        yield
64    finally:
65        GradModeVariable.create(tx, org_value, initialized=True)
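# Tracing-time helper: GradModeVariable.create(tx, enable) flips grad mode before the
# wrapped block runs, and the finally clause restores the original value afterwards, so
# only the subgraph traced inside the `with` block sees the requested grad mode.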
66
67
68def only_consist_of(var, types, allow_none=False):
69    if isinstance(var, types):
70        return True
71    if allow_none and var.is_python_constant() and var.as_python_constant() is None:
72        return True
73    if isinstance(var, (TupleVariable, ListVariable)):
74        return all(only_consist_of(item, types, allow_none) for item in var.items)
75    if isinstance(var, ConstDictVariable):
76        return all(
77            only_consist_of(item, types, allow_none) for item in var.items.values()
78        )
79    return False
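# Illustrative behavior sketch (tensor_vt and some_constant_vt are placeholder names
# standing in for existing VariableTrackers):
#
#     only_consist_of(tensor_vt, (TensorVariable,))                          # True
#     only_consist_of(TupleVariable([tensor_vt, ConstantVariable(None)]),
#                     (TensorVariable,), allow_none=True)                    # True
#     only_consist_of(ConstDictVariable({key_vt: some_constant_vt}),
#                     (TensorVariable,))                                     # False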
80
81
82# A more readable syntactic sugar for creating a UserFunctionVariable for f
83# and running call_function on it. It returns a function so as to preserve the
84# calling convention of the original f.
85def _make_inlined(tx, f):
86    assert callable(f), "Expect f to be a python callable."
87
88    def inline_call(*args, **kwargs):
89        return UserFunctionVariable(f).call_function(tx, args, kwargs)
90
91    return inline_call
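# Illustrative usage (this pattern appears later in this file):
#
#     flat_args, tree_spec = _make_inlined(tx, pytree.tree_flatten)(
#         ListVariable(sub_args)
#     ).unpack_var_sequence(tx)
#
# i.e. the Python callable is wrapped in a UserFunctionVariable and inlined through
# Dynamo, so the call returns VariableTrackers rather than concrete Python values.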
92
93
94def _call_function_and_unflatten_output(
95    tx, fn, args, kwargs, flat_example_value, ret_treespec
96):
97    from .builder import wrap_fx_proxy
98
99    # Store the invocation as a call
100    flat_variable = wrap_fx_proxy(
101        tx=tx,
102        proxy=tx.output.create_proxy(
103            "call_function",
104            fn,
105            args=args,
106            kwargs=kwargs,
107        ),
108        example_value=flat_example_value,
109    )
110
111    # Transform variable back into a list (previously made into a tuple by
112    # speculate_subgraph function) so as to respect the pytree API typing.
113    flat_list_variable = BuiltinVariable(list).call_function(tx, [flat_variable], {})
114    return (
115        _make_inlined(tx, pytree.tree_unflatten)(flat_list_variable, ret_treespec)
116        if ret_treespec
117        else flat_variable
118    )
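# In short: emit a call_function node for the HOP `fn`, wrap the flat result in a
# proxy-backed VariableTracker, and, if a treespec was produced by speculate_subgraph,
# unflatten the tuple output back into the pytree structure the traced body returned.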
119
120
121def _assert_tensors_nonaliasing(inputs, outputs):
122    input_tensor_ids = {
123        id(t) for t in pytree.tree_leaves(inputs) if isinstance(t, torch.Tensor)
124    }
125    output_tensor_ids = {
126        id(t) for t in pytree.tree_leaves(outputs) if isinstance(t, torch.Tensor)
127    }
128    assert input_tensor_ids.isdisjoint(
129        output_tensor_ids
130    ), "inputs to function body cannot alias outputs"
131
132
133def _check_supported_callable_arg(tx, func_var: VariableTracker, arg_name):
134    is_callable = (
135        BuiltinVariable(callable).call_function(tx, [func_var], {}).as_python_constant()
136    )
137    if not is_callable:
138        unimplemented(f"{arg_name} is of unsupported callable type {str(func_var)}.")
139
140
141def validate_args_and_maybe_create_graph_inputs(
142    sub_args,
143    tracer,
144    tx,
145    set_subgraph_inputs,
146    description,
147):
148    from . import AutogradFunctionContextVariable
149    from .builder import wrap_fx_proxy_cls
150
151    assert tracer.parent is not None
152
153    if set_subgraph_inputs == "flatten_manual":
154        flat_args, tree_spec = _make_inlined(tx, pytree.tree_flatten)(
155            ListVariable(sub_args)
156        ).unpack_var_sequence(tx)
157
158        flat_inputs = validate_args_and_maybe_create_graph_inputs(
159            flat_args.unpack_var_sequence(tx),
160            tracer,
161            tx,
162            set_subgraph_inputs="manual",
163            description=description,
164        )
165
166        return _make_inlined(tx, pytree.tree_unflatten)(
167            ListVariable(flat_inputs), tree_spec
168        ).unpack_var_sequence(tx)
169    else:
170        args = []
171        for a in sub_args:
172            assert isinstance(a, VariableTracker)
173            if set_subgraph_inputs == "automatic":
174                args.append(a)
175                continue
176            elif set_subgraph_inputs == "semi_automatic":
177                if isinstance(a, AutogradFunctionContextVariable):
178                    tracer.create_graph_input(a.as_proxy().node.name)
179                elif a.maybe_fx_node() is not None:
180                    node = a.maybe_fx_node()
181                    new_proxy = tracer.create_graph_input(node.name)
182                    example_value = (
183                        node.meta["example_value"]
184                        if "example_value" in node.meta
185                        else None
186                    )
187                    a = wrap_fx_proxy_cls(
188                        target_cls=type(a),
189                        tx=tx,
190                        proxy=new_proxy,
191                        example_value=example_value,
192                    )
193                args.append(a)
194                continue
195
196            if a.is_python_constant():
197                # This arg is not used in the body of the higher order op.
198                # Currently, this new input is added only to satisfy callers
199                # that expect a fixed number of arguments. In the
200                # future, we can clean this up.
201                tracer.create_graph_input("const")
202                new_arg = a
203            # Weird special case; we probably want to delete it or fold it
204            # into the next case (of `a` being placeable into a graph)
205            elif isinstance(a, AutogradFunctionContextVariable):
206                tracer.create_graph_input(a.as_proxy().node.name)
207                new_arg = a
208            # If `a` can be put into a graph
209            elif a.maybe_fx_node() is not None:
210                node = a.maybe_fx_node()
211                new_proxy = tracer.create_graph_input(node.name)
212                example_value = (
213                    node.meta["example_value"] if "example_value" in node.meta else None
214                )
215                new_arg = wrap_fx_proxy_cls(
216                    target_cls=type(a),
217                    tx=tx,
218                    proxy=new_proxy,
219                    example_value=example_value,
220                )
221            # If `a` cannot be put into a graph
222            else:
223                # HOPs work much better if they use speculate_subgraph(set_subgraph_inputs="automatic").
224                unimplemented(
225                    f"{description} with body that accepts non-Tensors as input. "
226                    f"Got: {a.python_type()}"
227                )
228            args.append(new_arg)
229        return args
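# Rough summary of the set_subgraph_inputs modes handled above:
#   - "flatten_manual": pytree-flatten sub_args, recurse with "manual", then unflatten.
#   - "automatic": pass VariableTrackers through unchanged; free variables are lifted to
#     placeholders lazily during subgraph tracing.
#   - "semi_automatic" / "manual": eagerly create a graph placeholder per arg (constants
#     get a dummy "const" placeholder; args that cannot be placed in a graph hit
#     unimplemented()).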
230
231
232# This helper function is used to make sure two graphs share the same input signature. For example,
233# in torch.cond, the two branches might lift different sets of tensors as inputs. This function
234# dedups the inputs and modifies the graphs so that they take the same set of inputs.
235def _merge_graph_inputs(
236    l_graph, l_lifted_freevars, l_name, r_graph, r_lifted_freevars, r_name
237):
238    def dedup_and_sort_lifted_freevars(l_lifted_freevars, r_lifted_freevars):
239        # The nn module attributes are guaranteed to be registered into the top-level graph module during
240        # higher order op speculation. Therefore, get_attr nodes in the two branches with the same
241        # target refer to the same attribute, and we can safely deduplicate them by their target.
242        #
243        # Note: ideally, dynamo should just create a single proxy for the same attribute of an nn module. But
244        # true_branch and false_branch belong to two separate tracing contexts, so they may register the same
245        # attribute to the top level separately. This creates two get_attr proxies for the same attribute
246        # that have different metadata, such as stack_trace (one stack trace for the true_branch,
247        # and the other for the false_branch). It seems better to discard the duplicate proxy explicitly in cond
248        # than to make dynamo create a single proxy for the same get_attr target.
249        def shared_getattrs(l_lifted_proxies, r_lifted_proxies):
250            true_targets = {
251                proxy.node.target: proxy
252                for proxy in l_lifted_proxies
253                if proxy.node.op == "get_attr"
254            }
255            l_shared_getattrs = {}
256            r_shared_getattrs = {}
257
258            for false_proxy in r_lifted_proxies:
259                if (
260                    false_proxy.node.op == "get_attr"
261                    and false_proxy.node.target in true_targets
262                ):
263                    true_proxy = true_targets[false_proxy.node.target]
264                    l_shared_getattrs[true_proxy] = true_proxy
265                    r_shared_getattrs[false_proxy] = true_proxy
266            return l_shared_getattrs, r_shared_getattrs
267
268        l_shared_getattrs, r_shared_getattrs = shared_getattrs(
269            l_lifted_freevars.keys(), r_lifted_freevars.keys()
270        )
271
272        l_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union(
273            l_shared_getattrs.keys()
274        )
275        r_shared_freevars = (l_lifted_freevars.keys() & r_lifted_freevars.keys()).union(
276            r_shared_getattrs.keys()
277        )
278        unique_l_freevars = l_lifted_freevars.keys() - l_shared_freevars
279        unique_r_freevars = r_lifted_freevars.keys() - r_shared_freevars
280
281        def _sort_by_name(vars):
282            return sorted(vars, key=lambda var: var.node.name)
283
284        return (
285            list(_sort_by_name(list(l_shared_freevars))),
286            list(_sort_by_name(list(r_shared_freevars))),
287            list(_sort_by_name(list(unique_l_freevars))),
288            list(_sort_by_name(list(unique_r_freevars))),
289        )
290
291    (l_shared, r_shared, unique_l, unique_r) = dedup_and_sort_lifted_freevars(
292        l_lifted_freevars, r_lifted_freevars
293    )
294
295    # Let's say we capture cond(pred, true_fn, false_fn, (x,))
296    # With set_subgraph_inputs set to automatic,
297    # true_fn has lifted variables x, a, b, c
298    # false_fn has lifted variables x, a, b, d
299    # Then fixup_branch_inps makes sure both branches have the same signature, i.e.:
300    # - true_fn(x, a, b, c_true_branch, d_false_branch)
301    # - false_fn(x, a, b, c_true_branch, d_false_branch)
302    #
303    # More formally, the signature has three parts in the following order:
304    # 1. used in both branches: x, a, b
305    # 2. only used in the true branch: c, suffixed with _true_branch
306    # 3. only used in the false branch: d, suffixed with _false_branch
307    # Within each part, we re-order the nodes by name to have a deterministic ordering for testing.
308    def fixup_branch_inps(graph, lifted_freevars, shared, unique_l, unique_r):
309        def _insert_or_replace_phs(new_args, name_suffix):
310            for arg in new_args:
311                new_ph = graph.placeholder(arg.node.name + name_suffix)
312                # Override with new_ph if an old placeholder exists.
313                if arg in lifted_freevars:
314                    old_ph = lifted_freevars[arg].node
315                    old_ph.replace_all_uses_with(new_ph)
316                    # replace_all_uses_with doesn't clean up users. Clean them manually so that we can erase the node.
317                    old_ph.users = {}
318                    graph.erase_node(old_ph)
319
320        first_not_ph_node = next(
321            node for node in graph.nodes if node.op != "placeholder"
322        )
323        with graph.inserting_before(first_not_ph_node):
324            _insert_or_replace_phs(shared, "")
325            _insert_or_replace_phs(unique_l, "_" + l_name)
326            _insert_or_replace_phs(unique_r, "_" + r_name)
327
328    fixup_branch_inps(l_graph, l_lifted_freevars, l_shared, unique_l, unique_r)
329    fixup_branch_inps(r_graph, r_lifted_freevars, r_shared, unique_l, unique_r)
330    return l_graph, r_graph, l_shared, r_shared, unique_l, unique_r
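# Worked example, continuing the cond(pred, true_fn, false_fn, (x,)) comment above: with
# true_fn lifting {x, a, b, c} and false_fn lifting {x, a, b, d},
#
#     _merge_graph_inputs(
#         true_graph, true_lifted_freevars, "true_branch",
#         false_graph, false_lifted_freevars, "false_branch",
#     )
#
# returns shared == [x, a, b] for both sides, unique_l == [c], unique_r == [d], and both
# graphs rewritten to take placeholders (x, a, b, c_true_branch, d_false_branch).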
331
332
333# See NOTE [HigherOrderOperator tracing design] for details of the design
334def speculate_subgraph(
335    tx,
336    f,
337    sub_args,
338    sub_kwargs,
339    description,
340    *,
341    # source_target is the .value of HigherOrderOpVariable and is the
342    # target of the proxy that we created for the higherOrderOperator.
343    source_target=None,
344    always_restore=False,
345    enable_grad=None,
346    # NOTE [argument `set_subgraph_inputs`]
347    # set_subgraph_inputs controls how to construct the subgraph's placeholders from sub_args.
348    # 1. if your HOP supports arbitrary inputs, use set_subgraph_inputs="automatic" (most recommended).
349    # 2. if your HOP supports only Tensor and symnode inputs, use set_subgraph_inputs="flatten_manual" (recommended).
350    # If sub_args contain a pytree structure (e.g. dict/list/tuple/set), the sub_args will be flattened first.
351    # Then the flattened args are manually set as the subgraph's placeholders.
352    # 3. if your HOP must preserve inputs that are not tensor or symnode as placeholders, e.g. AutogradFunctionContextVariable,
353    # use set_subgraph_inputs="manual" (not recommended). We do not recommend it in general because it has the
354    # restriction that the user needs to manually control how to create placeholders and VariableTrackers for the args.
355    set_subgraph_inputs="automatic",
356    restore_side_effects=True,
357    should_flatten_outputs=False,
358    # Pass in an originating tracer - this is needed for preserving context
359    # across fwd-bwd for autograd.Function
360    tracer=None,
361):
362    if sub_kwargs is None:
363        sub_kwargs = {}
364
365    assert set_subgraph_inputs in {
366        "automatic",
367        "semi_automatic",
368        "flatten_manual",
369        "manual",
370    }, "Please use one of the supported set_subgraph_inputs options."
371
372    # See NOTE [argument `set_subgraph_inputs`]
373    if sub_kwargs and set_subgraph_inputs != "automatic":
374        unimplemented("Use `set_subgraph_inputs=automatic` when passing `sub_kwargs`.")
375
376    try:
377        # ensure guards on args get installed in parent subgraph
378        f, sub_args, sub_kwargs = LazyVariableTracker.realize_all(
379            (f, sub_args, sub_kwargs),
380        )
381
382        with tx.output.subtracer(source_target, tracer) as subtracer:
383            args = validate_args_and_maybe_create_graph_inputs(
384                sub_args, subtracer, tx, set_subgraph_inputs, description
385            )
386
387            validate_args_and_maybe_create_graph_inputs(
388                sub_kwargs.values(),
389                subtracer,
390                tx,
391                set_subgraph_inputs="automatic",
392                description=description,
393            )
394
395            autograd_ctx = (
396                dynamo_enable_grad(tx, enable_grad)
397                if enable_grad is not None
398                else contextlib.nullcontext()
399            )
400
401            # For handling side effects, one could argue that we don't
402            # have to do anything here. The side effects infra does a good job
403            # of graph breaking if we mutate any nonlocal or global variable
404            # while subtracing. As a result, if tracing succeeds, the side effects
405            # data structure will only contain read-only data structures that
406            # are put there for tracking purposes.
407            # On the other hand, if we ever write
408            # a new side effect in Dynamo which does not go through the side
409            # effect infra, we can end up in a bad state.
410            # Therefore we restore the side effects after tracing. The catch is
411            # that we have to handle tensor variables specially. If we have seen a
412            # nonlocal tensor variable during subtracing, we want to keep
413            # track of that tensor, so that later subtracing or the root tracer
414            # itself does not create a new proxy for the already observed tensor
415            # variable.
416            if restore_side_effects:
417                prev_side_effects = tx.output.side_effects.clone()
418
419            with autograd_ctx:
420                output = f.call_function(tx, args, sub_kwargs)
421
422            if restore_side_effects:
423                new_side_effects = tx.output.side_effects.clone()
424                prev_side_effects.track_tensor_variables_from_runahead_side_effects(
425                    new_side_effects
426                )
427                tx.output.side_effects = prev_side_effects
428
429            treespec = None
430            if should_flatten_outputs:
431                # Flatten the speculated subgraph output.
432                output, treespec = _make_inlined(tx, pytree.tree_flatten)(
433                    output
434                ).unpack_var_sequence(tx)
435                # Actually, transform the list (returned by flatten) into a tuple
436                # for dynamo consistency.
437                output = BuiltinVariable(tuple).call_function(tx, [output], {})
438
439            # Register output to graph
440            # Modeled off of compile_and_call_fx_graph
441            # TODO: support pytree output
442            # We check always_restore because we don't use the output or side effects of always_restore code,
443            # like bwd.
444            if always_restore:
445                # Nothing left to do here
446                return (output, treespec), tx.output.graph, subtracer.lifted_freevars
447            else:
448                from . import TensorVariable
449
450                if not only_consist_of(output, TensorVariable, allow_none=True):
451                    unimplemented(
452                        "HigherOrderOperator body's output must consist of tensors only"
453                    )
454
455                # The output proxies might not belong to this SubgraphTracer
456                # (if they are free variables that were never lifted)
457                # so lift them here.
458                output_proxies = output.as_proxy()
459                output_proxies = pytree.tree_map(
460                    subtracer.maybe_lift_tracked_freevar_to_input, output_proxies
461                )
462
463                tx.output.create_node(
464                    "output",
465                    "output",
466                    (subtracer.create_arg((output_proxies,))),
467                    {},
468                )
469                graph = tx.output.graph
470                graph.lint()
471                lifted_freevars = subtracer.lifted_freevars
472
473                return (
474                    (output, treespec),
475                    graph,
476                    lifted_freevars,
477                )
478
479    except Unsupported as ex:
480        f_name = f"{type(f).__name__}"
481        if isinstance(f, UserFunctionVariable):
482            f_name = f.get_name()
483        msg = (
484            f"speculate_subgraph: while introspecting {description}, we were unable "
485            f"to trace function `{f_name}` into a single graph. This means "
486            f"that Dynamo was unable to prove safety for this API and will "
487            f"fall back to eager-mode PyTorch, which could lead to a slowdown."
488        )
489        log.info(msg)
490        log.info(ex)
491        raise ex
492
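# Illustrative call pattern (mirroring CondHigherOrderVariable.call_function below;
# fn_vt and operands are placeholder names for this sketch):
#
#     (
#         (ret_val, ret_treespec),
#         ret_graph,
#         ret_lifted_freevars,
#     ) = speculate_subgraph(
#         tx,
#         fn_vt,       # a callable VariableTracker, e.g. a UserFunctionVariable
#         operands,    # list of VariableTrackers used as subgraph inputs
#         {},
#         "cond",
#         source_target=self.value,
#         should_flatten_outputs=True,
#     )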
493
494def make_attr(tx, name):
495    node = tx.output.create_proxy(
496        "get_attr",
497        name,
498        (),
499        {},
500    )
501    return node
502
503
504def add_subgraph(tx, name, gm):
505    next_name = None
506    i = 0
507    while not next_name:
508        candidate = f"{name}_{i}"
509        if candidate in tx.output.nn_modules:
510            i += 1
511        else:
512            next_name = candidate
513
514    gm.__name__ = next_name
515    gm.torchdynamo_force_dynamic = False
516    # This graph module is not present in the user space, so it can't be
517    # accessed by a source. Set source=None.
518    tx.output.register_attr_or_module(gm, next_name, source=None)
519    return next_name
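# add_subgraph installs `gm` on the output graph under the first free name of the form
# f"{name}_{i}"; for example, the first cond traced in a frame registers its branches as
# "cond_true_0" and "cond_false_0".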
520
521
522class TorchHigherOrderOperatorVariable(VariableTracker):
523    def __init__(
524        self, value: HigherOrderOperator, source: Optional[Source] = None, **kwargs
525    ):
526        super().__init__(**kwargs)
527        self.value = value
528        self.source = source
529
530    @staticmethod
531    def make(value, source=None, **kwargs):
532        if value.__name__ == "cond":
533            return CondHigherOrderVariable(value, source, **kwargs)
534        elif value.__name__ == "while_loop":
535            return WhileLoopHigherOrderVariable(value, source, **kwargs)
536        elif value.__name__ in ("map", "map_impl"):
537            return MapHigherOrderVariable(value, source, **kwargs)
538        elif value.__name__ == "executorch_call_delegate":
539            return ExecutorchCallDelegateHigherOrderVariable(value, source, **kwargs)
540        elif value.__name__ == "out_dtype":
541            return OutDtypeHigherOrderVariable(value, source, **kwargs)
542        elif value.__name__ == "wrap":
543            return WrapHigherOrderVariable(value, source, **kwargs)
544        elif value.__name__ == "flex_attention":
545            return TemplatedAttentionHigherOrderVariable(value, source, **kwargs)
546        elif value.__name__ in (
547            "wrap_activation_checkpoint",
548            "tag_activation_checkpoint",
549        ):
550            return CheckpointHigherOrderVariable(value, source, **kwargs)
551        elif value.__name__ == "_export_tracepoint":
552            return ExportTracepointHigherOrderVariable(value, source, **kwargs)
553        elif value.__name__ == "trace_wrapped":
554            return TraceWrappedHigherOrderOperatorVariable(value, source, **kwargs)
555        elif value.__name__ == "strict_mode":
556            return StrictModeHigherOrderVariable(value, source, **kwargs)
557        elif value.__name__ == "associative_scan":
558            return AssociativeScanHigherOrderVariable(value, source, **kwargs)
559        elif value.__name__ == "call_torchbind":
560            return CallTorchbindHigherOrderVariable(value, source, **kwargs)
561        else:
562            unimplemented(f"HigherOrderOperator {value.__name__}")
563
564    def call_function(
565        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
566    ) -> VariableTracker:
567        unimplemented(f"HigherOrderOperator {self.value.__name__}")
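# Illustrative dispatch: HOP VariableTrackers are built through make(), which switches on
# the operator's __name__, e.g.
#
#     vt = TorchHigherOrderOperatorVariable.make(torch.ops.higher_order.cond)
#     # vt is a CondHigherOrderVariable wrapping torch.ops.higher_order.cond
#
# Unrecognized operator names fall through to unimplemented().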
568
569
570class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
571    @raise_hard_error_if_graph_break(
572        reason="Cond doesn't work unless it is captured completely with torch.compile."
573    )
574    def call_function(
575        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
576    ) -> "VariableTracker":
577        from . import ListVariable, TensorVariable
578
579        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
580
581        for i, k in enumerate(["pred", "true_fn", "false_fn", "operands"]):
582            if v := kwargs.pop(k, None):
583                assert i == len(
584                    args
585                ), "did not provide the right number of non-keyword args"
586                args.append(v)
587
588        if kwargs:
589            unimplemented(f"torch.cond: Got unexpected kwargs: {list(kwargs.keys())}")
590
591        # TODO(voz): Support fake tensor dispatch for recursive
592        # ops - see torch/dispatch/_dispatcher.py
593        if len(args) != 4:
594            unimplemented(
595                f"Expected 4 arguments but got {len(args)}.\n"
596                f"Usage: cond(pred, true_fn, false_fn, operands)",
597            )
598        # predicate
599        if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable):
600            unimplemented(
601                f"Expected pred to be a bool or a boolean tensor with a single "
602                f"item but got {str(type(args[0]))} "
603                f"with original python type {str(args[0].python_type())}.",
604            )
605
606        # operands
607        if not isinstance(args[3], (ListVariable, TupleVariable)):
608            unimplemented(
609                f"Expected a tuple but got {args[3].python_type()}",
610            )
611        operands = args[3].unpack_var_sequence(tx)
612        if not only_consist_of(args[3], (TensorVariable,)):
613            unimplemented(
614                "Expected operands to be a tuple of pytrees that only consist of tensor leaves."
615            )
616
617        # branches
618        _check_supported_callable_arg(tx, args[1], "true_fn")
619        _check_supported_callable_arg(tx, args[2], "false_fn")
620
621        # Our strategy for tracing the true/false branches of cond
622        # are to checkpoint our graphstate, run the true branch,
623        # roll it back to the checkpoint, and run the false
624        # branch, and then merge the graphstates.  Well, perhaps
625        # "merge" is too strong a word: we mostly assert that
626        # the resulting graphstates have to be the same.
627        #
628        # We only permit guards to diverge (we union the guards from
629        # both branches).  In particular, this means that side
630        # effects are NOT permitted inside true/false branches; this
631        # would be difficult to implement, because of the path
632        # explosion problem.
633
634        def speculate_branch(branch):
635            # NB: 0 is predicate
636            ix = 1 if branch else 2
637            # TODO: Support kwargs
638            (
639                (ret_val, ret_treespec),
640                ret_graph,
641                ret_lifted_freevars,
642            ) = speculate_subgraph(
643                tx,
644                args[ix],
645                operands,
646                {},
647                "cond",
648                source_target=self.value,
649                should_flatten_outputs=True,
650            )
651
652            if not only_consist_of(ret_val, (TensorVariable,)):
653                unimplemented(
654                    "Expected branches to return a possibly nested list/tuple/dict of tensors but it consists of non tensors.",
655                )
656            return ret_val, ret_treespec, ret_graph, ret_lifted_freevars
657
658        (true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch(
659            True
660        )
661        true_nn_modules = dict(tx.output.nn_modules)
662
663        (
664            false_r,
665            false_treespec,
666            false_graph,
667            false_lifted_freevars,
668        ) = speculate_branch(False)
669        false_nn_modules = dict(tx.output.nn_modules)
670
671        same_treespec = _make_inlined(tx, pytree.TreeSpec.__eq__)(
672            true_treespec, false_treespec
673        )
674        if not same_treespec.as_python_constant():
675            unimplemented("Expected branches to return the same pytree structure.")
676
677        def diff_meta(tensor_vars1, tensor_vars2):
678            assert all(
679                isinstance(var, TensorVariable) for var in tensor_vars1 + tensor_vars2
680            )
681            all_diffs = []
682            for i, (var1, var2) in enumerate(zip(tensor_vars1, tensor_vars2)):
683                # We check the meta data associated with meta["example_value"]
684                meta1 = _extract_tensor_metadata(
685                    var1.proxy.node.meta["example_value"], include_contiguity=False
686                )
687                meta2 = _extract_tensor_metadata(
688                    var2.proxy.node.meta["example_value"], include_contiguity=False
689                )
690                if meta1 != meta2:
691                    all_diffs.append((f"pair{i}:", meta1, meta2))
692            return all_diffs
693
694        if diffs := diff_meta(
695            true_r.unpack_var_sequence(tx), false_r.unpack_var_sequence(tx)
696        ):
697            unimplemented(
698                f"Expected branches to return tensors with same metadata. [(tensor_pair, difference)...]:{diffs}"
699            )
700
701        (
702            true_graph,
703            false_graph,
704            true_shared,
705            false_shared,
706            unique_true,
707            unique_false,
708        ) = _merge_graph_inputs(
709            true_graph,
710            true_lifted_freevars,
711            "true_branch",
712            false_graph,
713            false_lifted_freevars,
714            "false_branch",
715        )
716
717        true_name = add_subgraph(
718            tx,
719            "cond_true",
720            torch.fx.GraphModule(true_nn_modules, true_graph),
721        )
722        false_name = add_subgraph(
723            tx,
724            "cond_false",
725            torch.fx.GraphModule(false_nn_modules, false_graph),
726        )
727
728        true_node = make_attr(tx, true_name)
729        false_node = make_attr(tx, false_name)
730
731        p_args = (
732            args[0].as_proxy(),
733            true_node,
734            false_node,
735            # We pick true_shared but it shouldn't matter
736            true_shared + unique_true + unique_false,
737        )
738
739        flat_example_value = pytree.tree_map_only(
740            torch.fx.Proxy,
741            lambda a: a.node.meta["example_value"],
742            true_r.as_proxy(),
743        )
744
745        return _call_function_and_unflatten_output(
746            tx,
747            torch.ops.higher_order.cond,
748            p_args,
749            {},
750            flat_example_value,
751            true_treespec,
752        )
753
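# For reference, the user-facing call that CondHigherOrderVariable captures looks like:
#
#     def true_fn(x):
#         return x.sin()
#
#     def false_fn(x):
#         return x.cos()
#
#     out = torch.cond(pred, true_fn, false_fn, (x,))
#
# where pred is a bool or a single-element boolean tensor and (x,) is a tuple whose pytree
# leaves are all tensors, matching the checks at the top of call_function above.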
754
755class CallTorchbindHigherOrderVariable(TorchHigherOrderOperatorVariable):
756    def __init__(self, hop, source, script_obj_var, method_name):
757        super().__init__(hop, source)
758        self.script_obj_var = script_obj_var
759        self.method_name = method_name
760
761    def call_function(
762        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
763    ) -> VariableTracker:
764        from .builder import wrap_fx_proxy
765
766        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
767
768        args_proxy = [arg.as_proxy() for arg in args]
769        kwargs_proxy = {k: v.as_proxy() for k, v in kwargs.items()}
770        return wrap_fx_proxy(
771            tx=tx,
772            proxy=tx.output.create_proxy(
773                "call_function",
774                self.value,
775                args=tuple(
776                    [self.script_obj_var.as_proxy(), self.method_name] + args_proxy
777                ),
778                kwargs=kwargs_proxy,
779            ),
780        )
781
782
783class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
784    @raise_hard_error_if_graph_break(
785        reason="while_loop doesn't work unless it is captured completely with torch.compile."
786    )
787    def call_function(
788        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
789    ) -> VariableTracker:
790        from . import TensorVariable
791
792        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
793
794        for i, k in enumerate(["cond_fn", "body_fn", "operands"]):
795            if v := kwargs.pop(k, None):
796                assert i == len(
797                    args
798                ), "did not provide the right number of non-keyword args"
799                args.append(v)
800
801        if kwargs:
802            unimplemented(
803                f"torch.while_loop: Got unexpected kwargs: {list(kwargs.keys())}"
804            )
805
806        if len(args) != 4:
807            unimplemented(
808                f"Expected 4 arguments but got {len(args)}.\n"
809                f"Usage: while_loop(cond_fn, body_fn, operands)",
810            )
811
812        _check_supported_callable_arg(tx, args[0], "cond_fn")
813        _check_supported_callable_arg(tx, args[1], "body_fn")
814
815        # operands
816        if not isinstance(args[2], (ListVariable, TupleVariable)):
817            unimplemented(
818                f"Expected a tuple but got {args[2].python_type()}",
819            )
820        operands = args[2].unpack_var_sequence(tx)
821        if not only_consist_of(args[2], (TensorVariable,)):
822            unimplemented(
823                "Expected operands to be a tuple of pytrees that only consist of tensor leaves."
824            )
825
826        # additional inputs check
827        if not isinstance(args[3], (ListVariable, TupleVariable)):
828            unimplemented(
829                f"Expected a tuple but got {args[3].python_type()}",
830            )
831        additional_inputs = args[3].unpack_var_sequence(tx)
832
833        (
834            (cond_r, cond_treespec),
835            cond_graph,
836            cond_lifted_freevars,
837        ) = speculate_subgraph(
838            tx,
839            args[0],
840            operands + additional_inputs,
841            {},
842            "while_loop",
843            source_target=self.value,
844            set_subgraph_inputs="manual",
845        )
846        cond_nn_modules = dict(tx.output.nn_modules)
847        if not isinstance(cond_r, TensorVariable):
848            unimplemented(
849                f"Expected cond_fn to return a tensor but got {cond_r.python_type()}",
850            )
851
852        cond_r_meta = _extract_tensor_metadata(
853            cond_r.proxy.node.meta["example_value"], include_contiguity=False
854        )
855        if not cond_r_meta.dtype == torch.bool or not cond_r_meta.shape == torch.Size(
856            []
857        ):
858            unimplemented(
859                f"Expected cond_fn to return a boolean scalar tensor (shape ()) but got shape {cond_r_meta.shape} and dtype {cond_r_meta.dtype}"
860            )
861
862        (
863            (body_r, body_treespec),
864            body_graph,
865            body_lifted_freevars,
866        ) = speculate_subgraph(
867            tx,
868            args[1],
869            operands + additional_inputs,
870            {},
871            "while_loop",
872            source_target=self.value,
873            set_subgraph_inputs="manual",
874            should_flatten_outputs=True,
875        )
876        (
877            cond_graph,
878            body_graph,
879            cond_shared,
880            body_shared,
881            cond_unique,
882            body_unique,
883        ) = _merge_graph_inputs(
884            cond_graph,
885            cond_lifted_freevars,
886            "cond_fn",
887            body_graph,
888            body_lifted_freevars,
889            "body_fn",
890        )
891
892        # Note: cond_shared and body_shared refer to the same proxies in the parent graph,
893        # so using either of them is OK. We use cond_shared since it doesn't matter.
894        additional_lifted_inputs = cond_shared + cond_unique + body_unique
895
896        body_nn_modules = dict(tx.output.nn_modules)
897
898        cond_name = add_subgraph(
899            tx,
900            "cond_fn",
901            torch.fx.GraphModule(cond_nn_modules, cond_graph),
902        )
903        body_name = add_subgraph(
904            tx,
905            "body_fn",
906            torch.fx.GraphModule(body_nn_modules, body_graph),
907        )
908
909        cond_node = make_attr(tx, cond_name)
910        body_node = make_attr(tx, body_name)
911
912        p_args = (
913            cond_node,
914            body_node,
915            tuple([operand.as_proxy() for operand in operands]),
916            tuple(
917                [inp.as_proxy() for inp in additional_inputs] + additional_lifted_inputs
918            ),
919        )
920
921        flat_example_value = pytree.tree_map_only(
922            torch.fx.Proxy,
923            lambda a: a.node.meta["example_value"],
924            body_r.as_proxy(),
925        )
926
927        return _call_function_and_unflatten_output(
928            tx,
929            torch.ops.higher_order.while_loop,
930            p_args,
931            {},
932            flat_example_value,
933            body_treespec,
934        )
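# At the HOP level this traces
#     torch.ops.higher_order.while_loop(cond_fn, body_fn, operands, additional_inputs)
# where cond_fn must return a boolean scalar tensor; both subgraphs are speculated above
# with operands + additional_inputs as their inputs.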
935
936
937class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
938    @raise_hard_error_if_graph_break(
939        reason="associative_scan must be captured completely with torch.compile."
940    )
941    def call_function(
942        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
943    ) -> VariableTracker:
944        from .builder import SourcelessBuilder, wrap_fx_proxy
945
946        args, kwargs = LazyVariableTracker.realize_all((args, kwargs))
947
948        def arg_extractor(combine_fn, input, dim):
949            return combine_fn, input, dim
950
951        combine_fn, input, dim = arg_extractor(*args, **kwargs)
952
953        if input.python_type() != list:
954            unimplemented(
955                f"Expected input to be a list of tensors but got {input.python_type()}",
956            )
957        assert isinstance(input, torch._dynamo.variables.lists.BaseListVariable)
958
959        # Trace the subgraph
960        # TODO: Fix these pointless new_empty calls appearing in the dynamo output graph.
961        null_shape = SourcelessBuilder.create(tx, ())
962        sub_args = [
963            leaf.call_method(tx, "new_empty", args=(null_shape,), kwargs={})
964            for leaf in itertools.chain(input.items, input.items)
965        ]
966        (
967            (combine_result, combine_treespec),
968            combine_graph,
969            combine_lifted_freevars,
970        ) = speculate_subgraph(
971            tx,
972            combine_fn,
973            sub_args,
974            sub_kwargs={},
975            description="scan_combine",
976            source_target=self.value,
977            set_subgraph_inputs="flatten_manual",
978        )
979
980        if combine_lifted_freevars:
981            unimplemented(
982                f"Combine fn had unexpected freevars: {combine_lifted_freevars}"
983            )
984
985        if combine_result.python_type() != list:
986            unimplemented(
987                f"Expected combine_fn to return a list of tensors but got {combine_result.python_type()}",
988            )
989
990        input_proxy = input.as_proxy()
991        combine_result_proxy = combine_result.as_proxy()
992        for result, inp_proxy in zip(combine_result_proxy, input_proxy):
993            inp_meta = inp_proxy.node.meta["example_value"]
994            combine_result_meta = result.node.meta["example_value"]
995            if combine_result_meta.device != inp_meta.device:
996                unimplemented(
997                    f"Expected combine_fn to return a tensor on device {inp_meta.device} but "
998                    + f"got {combine_result_meta.device}"
999                )
1000            if combine_result_meta.dtype != inp_meta.dtype:
1001                unimplemented(
1002                    f"Expected combine_fn to return a tensor of {inp_meta.dtype} but "
1003                    + f"got {combine_result_meta.dtype}"
1004                )
1005
1006            if combine_result_meta.shape != ():
1007                unimplemented(
1008                    f"Expected combine_fn to return a tensor with shape () but got {combine_result_meta.shape}"
1009                )
1010
1011        combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
1012        combine_fn_name = add_subgraph(tx, "scan_combine", combine_gm)
1013
1014        p_args = (
1015            make_attr(tx, combine_fn_name),
1016            input_proxy,
1017            dim.as_proxy(),
1018        )
1019
1020        with tx.fake_mode:
1021            out_meta = tuple(
1022                inp_proxy.node.meta["example_value"].clone()
1023                for inp_proxy in input_proxy
1024            )
1025        return wrap_fx_proxy(
1026            tx=tx,
1027            proxy=tx.output.create_proxy(
1028                "call_function", torch.ops.higher_order.associative_scan, p_args, {}
1029            ),
1030            example_value=out_meta,
1031        )
1032
1033
1034def non_single_tensor_return_unsupported(api, ret):
1035    from . import TensorVariable
1036
1037    if not isinstance(ret, TensorVariable):
1038        raise Unsupported(
1039            f"{api} over a function that returns something other than one Tensor"
1040        )
1041
1042
1043class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
1044    def call_function(
1045        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
1046    ) -> VariableTracker:
1047        from . import TensorVariable
1048        from .builder import wrap_fx_proxy_cls
1049
1050        if len(kwargs) > 0:
1051            unimplemented(
1052                "torch.ops.higher_order.map: kwargs are not supported in the map operator."
1053            )
1054
1055        _check_supported_callable_arg(tx, args[0].realize(), "map_fn")
1056
1057        assert type(args[1].realize()) is TensorVariable
1058
1059        sample_shape = get_fake_value(args[1].as_proxy().node, tx).size()
1060
1061        if len(sample_shape) < 1 or sample_shape[0] == 0:
1062            unimplemented(
1063                "map() operator doesn't support scalar or zero-sized tensors during tracing."
1064            )
1065
1066        # To get the example output from map() we will need to provide at least one sample to
1067        # the loop body. In our case we will always use xs[0], and our map() won't support
1068        # zero-sized tensors during tracing.
1069        first_dim = wrap_fx_proxy_cls(
1070            target_cls=TensorVariable, tx=tx, proxy=args[1].as_proxy()[0]
1071        )
1072
1073        # TODO: Support kwargs
1074        (
1075            (body_r, body_spec),
1076            body_graph,
1077            body_lifted_freevars,
1078        ) = speculate_subgraph(
1079            tx,
1080            args[0],
1081            [
1082                first_dim,
1083                *args[2:],
1084            ],
1085            {},
1086            "torch.ops.higher_order.map",
1087            source_target=self.value,
1088            set_subgraph_inputs="flatten_manual",
1089            should_flatten_outputs=True,
1090        )
1091
1092        subgraph_example_value = [
1093            proxy.node.meta["example_value"] for proxy in body_r.as_proxy()
1094        ]
1095
1096        with tx.output.fake_mode:
1097            # We need to expand the example output from map() so that it has
1098            # the same first dimension as the mapped input.
1099            # We also do a clone with contiguous_format. This is to be consistent with
1100            # the eager semantics of map, which stacks the outputs. The result is contiguous
1101            # as a result of the stack operation.
1102            map_example_out = [
1103                t.expand(sample_shape[0], *t.size()).clone(
1104                    memory_format=torch.contiguous_format
1105                )
1106                for t in subgraph_example_value
1107            ]
1108
1109        body_nn_modules = dict(tx.output.nn_modules)
1110
1111        body_name = add_subgraph(
1112            tx,
1113            "map_body",
1114            torch.fx.GraphModule(body_nn_modules, body_graph),
1115        )
1116
1117        body_node = make_attr(tx, body_name)
1118
1119        p_args = (
1120            body_node,
1121            [args[1].as_proxy()],
1122            [arg.as_proxy() for arg in args[2:]] + list(body_lifted_freevars.keys()),
1123        )
1124
1125        return _call_function_and_unflatten_output(
1126            tx, torch.ops.higher_order.map_impl, p_args, {}, map_example_out, body_spec
1127        )
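# Reminder of the semantics this models: map applies the traced body to each slice xs[i]
# of the mapped input and stacks the per-slice outputs along a new leading dimension,
# which is why the example outputs above are expanded to sample_shape[0] and cloned as
# contiguous.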
1128
1129
1130class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable):
1131    def call_function(
1132        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
1133    ) -> "VariableTracker":
1134        from .builder import wrap_fx_proxy
1135
1136        # This is the operator for delegation within Executorch, which calls a
1137        # specific function in the given lowered module with the given
1138        # operators. The actual operator is defined in the Executorch codebase.
1139        # This is a bad hierarchical violation since
1140        # executorch_call_delegate sits at a higher level than dynamo, but
1141        # there's no real solution to this issue yet.
1142        if len(kwargs) > 0:
1143            unimplemented(
1144                "executorch_call_delegate: kwargs are not supported."
1145            )
1146        lowered_module = tx.output.get_submodule(args[0].module_key)
1147
1148        lowered_node = make_attr(tx, args[0].module_key)
1149
1150        p_args = tuple(arg.as_proxy() for arg in args[1:])
1151        real_sub_args = pytree.tree_map_only(
1152            torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args
1153        )
1154
1155        example_value = lowered_module.original_module.module()(*real_sub_args)
1156
1157        # NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]:
1158        # executorch modules promise not to alias inputs and outputs.
1159        # Thus, output FakeTensors will correctly not alias input FakeTensors.
1160        _assert_tensors_nonaliasing(real_sub_args, example_value)
1161
1162        p_args = (lowered_node,) + p_args
1163
1164        # Store the invocation as a call
1165        return wrap_fx_proxy(
1166            tx=tx,
1167            proxy=tx.output.create_proxy(
1168                "call_function",
1169                self.value,
1170                args=tuple(p_args),
1171                kwargs={},
1172            ),
1173            example_value=example_value,
1174        )
1175
1176
1177class FunctorchHigherOrderVariable(UserFunctionVariable):
1178    def call_function(
1179        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
1180    ) -> "VariableTracker":
1181        if not torch._dynamo.config.capture_func_transforms:
1182            name = self.get_name()
1183            fn = {
1184                "grad_impl": "grad",
1185                "vmap_impl": "vmap",
1186                "vjp": "vjp",
1187                "jvp": "jvp",
1188                "jacrev": "jacrev",
1189                "jacfwd": "jacfwd",
1190                "hessian": "hessian",
1191                "linearize": "linearize",
1192            }.get(name)
1193            assert fn is not None
1194            unimplemented(
1195                f"torch.func.{fn} capture is disabled, "
1196                "it can be turned on by setting "
1197                "`torch._dynamo.config.capture_func_transforms=True`"
1198            )
1199        return super().call_function(tx, args, kwargs)
1200
1201
1202class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
1203    def create_wrapped_node(self, tx, args, kwargs, description):
1204        # See NOTE [HigherOrderOperator tracing design] for more details
1205
1206        (
1207            (body_r, treespec),
1208            body_graph,
1209            body_lifted_freevars,
1210        ) = speculate_subgraph(
1211            tx,
1212            args[0],  # function
1213            [*args[1:]],
1214            kwargs,
1215            description,
1216            source_target=self.value,
1217            should_flatten_outputs=True,
1218        )
1219
1220        body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
1221        body_name = add_subgraph(
1222            tx,
1223            "wrap_body",
1224            body_gmod,
1225        )
1226
1227        body_node = make_attr(tx, body_name)
1228
1229        # Since we call `speculate_subgraph` with `set_subgraph_inputs="automatic"`,
1230        # all the arguments are lifted.
1231        lifted_args = tuple(arg for arg in body_lifted_freevars.keys())
1232
1233        proxy_args = (body_node,) + lifted_args
1234        example_value = pytree.tree_map_only(
1235            torch.fx.Proxy,
1236            lambda a: a.node.meta["example_value"],
1237            body_r.as_proxy(),
1238        )
1239
1240        return proxy_args, {}, example_value, body_r, treespec, body_gmod
1241
1242    def call_function(
1243        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
1244    ) -> "VariableTracker":
1245        # This flattens the kwargs into lifted args
1246        p_args, p_kwargs, example_value, body_r, treespec, _ = self.create_wrapped_node(
1247            tx, args, kwargs, "wrap"
1248        )
1249
1250        if len(p_kwargs) > 0:
1251            unimplemented("kwargs should have been flattened into lifted args")
1252
1253        flat_example_value = pytree.tree_map_only(
1254            torch.fx.Proxy,
1255            lambda a: a.node.meta["example_value"],
1256            body_r.as_proxy(),
1257        )
1258
1259        return _call_function_and_unflatten_output(
1260            tx, self.value, tuple(p_args), p_kwargs, flat_example_value, treespec
1261        )
1262
1263
1264class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable):
1265    def call_function(
1266        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
1267    ) -> "VariableTracker":
1268        from .builder import wrap_fx_proxy
1269
1270        if len(kwargs) > 0:
1271            unimplemented("out_dtype does not handle kwargs")
1272
1273        p_args = tuple(arg.as_proxy() for arg in args)
1274        op = p_args[0]
1275        output_dtype = p_args[1]
1276        fake_sub_args = pytree.tree_map_only(
1277            torch.fx.Proxy, lambda a: a.node.meta["example_value"], p_args[2:]
1278        )
1279        # This is a simplified implementation of this operator just for tracing.
1280        # Actual implementation may also first promote the arguments
1281        example_value = op(*fake_sub_args).to(dtype=output_dtype)
1282
1283        # Store the invocation as a call
1284        return wrap_fx_proxy(
1285            tx=tx,
1286            proxy=tx.output.create_proxy(
1287                "call_function",
1288                self.value,
1289                args=tuple(p_args),
1290                kwargs={},
1291            ),
1292            example_value=example_value,
1293        )
1294
1295
1296class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable):
1297    @raise_hard_error_if_graph_break(
1298        reason="strict_mode HOO doesn't work unless it is captured completely with torch.compile."
1299    )
1300    def call_function(
1301        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
1302    ) -> "VariableTracker":
1303        callable = args[0]
1304
1305        unpacked_sequence = args[1].unpack_var_sequence(tx)
1306        # TODO (tmanlaibaatar) support pytree here
1307        for arg in unpacked_sequence:
1308            if isinstance(arg, (ListVariable, TupleVariable, ConstDictVariable)):
1309                unimplemented("strict_mode HOO only works for flat inputs for now")
1310
1311        if kwargs:
1312            unimplemented(
1313                f"strict_mode HOO received unexpected kwargs: {list(kwargs.keys())}"
1314            )
1315
1316        (
1317            (ret_val, ret_treespec),
1318            ret_graph,
1319            ret_lifted_freevars,
1320        ) = speculate_subgraph(
1321            tx,
1322            args[0],
1323            unpacked_sequence,
1324            {},
1325            "strict_mode",
1326            source_target=self.value,
1327            should_flatten_outputs=True,
1328        )
1329
1330        strict_mode_nn_modules = dict(tx.output.nn_modules)
1331
1332        strict_mode_name = add_subgraph(
1333            tx,
1334            "strict_mode_body",
1335            torch.fx.GraphModule(strict_mode_nn_modules, ret_graph),
1336        )
1337
1338        strict_mode_node = make_attr(tx, strict_mode_name)
1339        p_args = (
1340            strict_mode_node,
1341            tuple(arg for arg in ret_lifted_freevars.keys()),
1342        )
1343
1344        flat_example_value = pytree.tree_map_only(
1345            torch.fx.Proxy,
1346            lambda a: a.node.meta["example_value"],
1347            ret_val.as_proxy(),
1348        )
1349
1350        return _call_function_and_unflatten_output(
1351            tx,
1352            torch.ops.higher_order.strict_mode,
1353            p_args,
1354            {},
1355            flat_example_value,
1356            ret_treespec,
1357        )
1358
1359
1360class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
1361    def call_function(
1362        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
1363    ) -> VariableTracker:
1364        from torch._higher_order_ops.wrap import TagActivationCheckpoint
1365        from torch.utils.checkpoint import noop_context_fn
1366        from .builder import wrap_fx_proxy
1367
1368        context_fn = None
1369        if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn:
1370            ctx = kwargs.pop("context_fn")
1371            if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable):
1372                context_fn = ctx.fn
1373            elif isinstance(
1374                ctx, torch._dynamo.variables.functions.FunctoolsPartialVariable
1375            ):
1376                context_fn = ctx.as_python_constant()
1377            else:
1378                raise NotImplementedError(
1379                    f"checkpoint not implemented for {type(ctx)} context_fn"
1380                )
1381
1382        checkpoint_kwargs, gmod_kwargs = TagActivationCheckpoint.divide_kwargs(kwargs)
1383
1384        # Here we use checkpoint_kwargs (and not gmod_kwargs). gmod_kwargs are
1385        # already flattened above and managed inside the fx graph.
1386        (
1387            p_args,
1388            _,
1389            example_value,
1390            body_r,
1391            treespec,
1392            checkpointed_gmod,
1393        ) = self.create_wrapped_node(
1394            tx, args, gmod_kwargs, "torch.utils.checkpoint.checkpoint"
1395        )
1396        if context_fn is not None:
1397            checkpointed_gmod.meta["_checkpoint_context_fn"] = context_fn
1398
1399        _, checkpoint_kwargs = proxy_args_kwargs([], checkpoint_kwargs)
1400
1401        # Store the invocation as a call
1402        variable = wrap_fx_proxy(
1403            tx=tx,
1404            proxy=tx.output.create_proxy(
1405                "call_function",
1406                self.value,
1407                args=tuple(p_args),
1408                kwargs=checkpoint_kwargs,
1409            ),
1410            example_value=example_value,
1411        )
1412
1413        if treespec is None:
1414            return variable
1415
1416        # Transform variable back into a list (it was previously made into a tuple
1417        # by speculate_subgraph) so as to respect the pytree API typing.
1418        variable = BuiltinVariable(list).call_function(tx, [variable], {})
1419
1420        return _make_inlined(tx, pytree.tree_unflatten)(variable, treespec)
1421
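# Illustrative usage that would exercise this path (a sketch, not part of this
# module; assumes the non-reentrant torch.utils.checkpoint API under torch.compile):
#
#   import torch
#   from torch.utils.checkpoint import checkpoint
#
#   def gn(x):
#       return torch.sin(x).cos()
#
#   @torch.compile(backend="eager")
#   def fn(x):
#       return checkpoint(gn, x, use_reentrant=False)
#
#   fn(torch.randn(4, requires_grad=True))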
1422
1423class ExportTracepointHigherOrderVariable(TorchHigherOrderOperatorVariable):
1424    def call_function(
1425        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
1426    ) -> "VariableTracker":
1427        from .builder import wrap_fx_proxy
1428
1429        p_args = tuple(arg.as_proxy() for arg in args)
1430        p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()}
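        # Forward the tracepoint call into the graph unchanged; args and kwargs are
        # proxied as-is and no example_value is attached to the resulting proxy.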
1431        return wrap_fx_proxy(
1432            tx=tx,
1433            proxy=tx.output.create_proxy(
1434                "call_function",
1435                self.value,
1436                args=p_args,
1437                kwargs=p_kwargs,
1438            ),
1439            example_value=None,
1440        )
1441
1442
1443class TraceWrappedHigherOrderOperatorVariable(TorchHigherOrderOperatorVariable):
1444    """
1445    Handles torch._dynamo._trace_wrapped_higher_order_op.inner_trace
1446    by unwrapping the higher-order op and inlining through it. This op
1447    is created by Dynamo to survive through AOTAutograd and is then unwrapped
1448    here when Dynamo is invoked again from compiled autograd.
1449    """
1450
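    # Sketch of the unwrapping performed below (names are illustrative):
    #
    #   trace_wrapped(*args, fn=hook_fn)  --inlined as-->  hook_fn(*args)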
1451    def call_function(
1452        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
1453    ) -> "VariableTracker":
1454        kwargs = dict(kwargs)
1455        fn = kwargs.pop("fn")
1456        return fn.call_function(tx, args, kwargs)
1457
1458
1459class TemplatedAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
1460    @staticmethod
1461    def normalize_to_args(args, kwargs):
1462        # input signature is (query, key, value, score_mod, *other_buffers)
1463        # Flatten args and kwargs into lists
1464        flat_args = pytree.tree_flatten(args)[0]
1465        flat_kwargs = pytree.tree_flatten(kwargs)[0]
1466
1467        # Combine the flattened lists
1468        all_args = flat_args + flat_kwargs
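        # For example (illustrative): flex_attention(query, key, value, score_mod=score_mod)
        # normalizes to the flat list [query, key, value, score_mod], which
        # call_function below unpacks positionally.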
1469        return all_args
1470
1471    def create_wrapped_node(
1472        self, tx, query: "VariableTracker", score_function: "VariableTracker"
1473    ):
1474        from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
1475        from .builder import SourcelessBuilder
1476
1477        tx: InstructionTranslator = tx
1478
1479        scores_require_grad: bool = query.requires_grad
1480        score = query.call_method(
1481            tx,
1482            "new_empty",
1483            (SourcelessBuilder.create(tx, []),),
1484            {"requires_grad": SourcelessBuilder.create(tx, scores_require_grad)},
1485        )
1486
1487        def create_scalar():
1488            return query.call_method(
1489                tx,
1490                "new_empty",
1491                (SourcelessBuilder.create(tx, []),),
1492                {
1493                    "dtype": SourcelessBuilder.create(tx, torch.int32),
1494                },
1495            )
1496
1497        bhmn = [create_scalar() for _ in range(4)]
1498        new_args = [score, *bhmn]
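        # These five placeholders line up with the (assumed) flex_attention score_mod
        # signature, e.g. (a sketch; the relative-position bias is illustrative):
        #
        #   def score_mod(score, b, h, q_idx, kv_idx):
        #       return score + (q_idx - kv_idx)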
1499
1500        with TransformGetItemToIndex():
1501            (
1502                (body_output, body_treespec),
1503                body_graph,
1504                body_lifted_freevars,
1505            ) = speculate_subgraph(
1506                tx,
1507                score_function,
1508                new_args,
1509            {},  # expect only args, no kwargs, for now
1510                description="flex_attention",
1511                source_target=self.value,
1512                set_subgraph_inputs="flatten_manual",
1513            )
1514
1515        body_name = add_subgraph(
1516            tx,
1517            "flex_attention",
1518            torch.fx.GraphModule(tx.output.nn_modules, body_graph),
1519        )
1520
1521        body_node = make_attr(tx, body_name)
1522
1523        # It is possible that the score_mod function captures some free variables that are not
1524        # passed in as arguments. In this case, they need to be lifted, which speculate_subgraph handles.
1525        # We then need to create proxies for these lifted free variables as well as for the inputs.
1526
1527        lifted_args = tuple(body_lifted_freevars.keys())
1528
1529        proxy_args = (body_node,) + lifted_args
1530
1531        return proxy_args
1532
1533    def call_function(
1534        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
1535    ) -> "VariableTracker":
1536        from .builder import wrap_fx_proxy
1537
1538        query, key, value, score_mod = self.normalize_to_args(args, kwargs)
1539
1540        p_args = self.create_wrapped_node(tx, query, score_mod)
1541        proxied_args = [query, key, value]
1542
1543        # Store the invocation as a call.
1544        # The normalized kwargs contain the score_function, which we don't want to proxy
1545        # because proxying user-defined functions is not supported.
1546        inp_args, _ = proxy_args_kwargs(proxied_args, {})
1547
1548        query_meta = query.as_proxy().node.meta["example_value"]
1549        logsumexp_shape = query_meta.size()[:-1]  # [B, H, M]
1550        with torch._guards.TracingContext.try_get().fake_mode:
1551            out_meta = torch.empty_like(
1552                query_meta, memory_format=torch.contiguous_format
1553            )
1554            lse_meta = query_meta.new_empty(logsumexp_shape, dtype=torch.float32)
1555        example_value = (out_meta, lse_meta)
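        # For example (illustrative): with query of shape [B, H, M, K], example_value
        # describes (out: same shape/dtype as query, logsumexp: [B, H, M] in float32),
        # matching the fake tensors constructed above.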
1556
1557        return wrap_fx_proxy(
1558            tx=tx,
1559            proxy=tx.output.create_proxy(
1560                "call_function",
1561                self.value,
1562                args=inp_args + p_args,
1563                kwargs={},
1564            ),
1565            example_value=example_value,
1566        )
1567
1568
1569class AutogradFunctionApplyVariable(VariableTracker):
1570    def __init__(self, fwd_graph, bwd_graph, parent_source, **kwargs):
1571        super().__init__(**kwargs)
1572        self.fwd_graph = fwd_graph
1573        self.bwd_graph = bwd_graph
1574        self.parent_source = parent_source
1575
1576    def call_function(
1577        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
1578    ) -> "VariableTracker":
1579        from . import (
1580            AutogradFunctionContextVariable,
1581            UserDefinedClassVariable,
1582            UserFunctionVariable,
1583            UserMethodVariable,
1584        )
1585        from .builder import wrap_fx_proxy
1586
1587        """
1588        Consider the following:
1589        class MySin(torch.autograd.Function):
1590            @staticmethod
1591            def forward(ctx, x):
1592                ctx.save_for_backward(x)
1593                return x.sin()
1594            @staticmethod
1595            def backward(ctx, grad):
1596                x, = ctx.saved_tensors
1597                return grad * x.cos()
1598        We want the resulting graphs to look like:
1599        def fwd(ctx, x):
1600            # (output, saved tensors / attrs)
1601            return (x.sin(), [x])
1602        # bwd(ctx, grad0, grad1, ..., gradn, *saved_tensors_or_attrs)
1603        def bwd(ctx, grad, x):
1604            return grad * x.cos()
1605        To accomplish this, we're going to:
1606        1. Construct a ctx object
1607        2. (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph on MySin.forward (set_subgraph_inputs="semi_automatic")
1608        3. (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph on MySin.backward, while manually setting
1609        the ctx and grad inputs.
1610        4. Manually rewriting the fwd graph's output to be (output, stuff_that_gets_used_in_bwd_graph)
1611        Getting from 3 to 4 is straightforward: stuff_that_gets_used_in_bwd_graph is
1612        just the bwd_freevars returned from speculate_subgraph, assuming MySin.backward
1613        doesn't capture any free variables of its own.
1614        All of these steps only work if MySin.backward doesn't capture any outside values;
1615        this is a general limitation that we should check for.
1616        """
1617
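        # The end product (sketched; names are illustrative) is a single FX node:
        #
        #   autograd_function_apply(fwd_body, bwd_body, *tensor_args, *lifted_freevars,
        #                           args_tensor_mask=[True, ...])
        #
        # where fwd_body / bwd_body refer to the subgraph attributes created below
        # via add_subgraph / make_attr.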
1618        prev_side_effects = tx.output.side_effects.clone()
1619        fwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
1620            tx.output,
1621            parent=tx.output.current_tracer,
1622            source_target="autograd.Function",
1623        )
1624
1625        fwd_src = AttrSource(self.parent_source, member="forward")
1626        ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
1627        if isinstance(self.fwd_graph, types.FunctionType):
1628            fwd_fn = UserFunctionVariable(self.fwd_graph)
1629            fwd_args = [ctx, *args]
1630        elif isinstance(self.fwd_graph, types.MethodType):
1631            fwd_fn = UserMethodVariable(
1632                self.fwd_graph.__func__,
1633                UserDefinedClassVariable(self.fwd_graph.__class__),
1634            )
1635            fwd_args = [fwd_fn.obj, ctx, *args]
1636        else:
1637            unimplemented("autograd.Function forward is neither a function nor a method")
1638
1639        # Speculate subgraph on the fwd
1640        (fwd_out, _), fwd_graph, fwd_freevars = speculate_subgraph(
1641            tx,
1642            fwd_fn,
1643            fwd_args,
1644            kwargs,
1645            "autograd.Function",
1646            set_subgraph_inputs="semi_automatic",
1647            restore_side_effects=False,
1648            tracer=fwd_tracer,
1649        )
1650
1651        if ctx.mutable_local in tx.output.side_effects.store_attr_mutations:
1652            if (
1653                "_materialize_non_diff_grads"
1654                in tx.output.side_effects.store_attr_mutations[ctx.mutable_local]
1655            ):
1656                unimplemented("NYI: mutation of ctx._materialize_non_diff_grads")
1657
1658        bwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
1659            tx.output,
1660            parent=fwd_tracer,
1661            source_target="autograd.Function",
1662        )
1663
1664        # Speculate subgraph on the backward. We make the
1665        # bwd tracer a child of the fwd tracer, because backward may rely on
1666        # tensors/attrs created in the fwd tracer.
1667
1668        if isinstance(fwd_out, variables.BaseListVariable):
1669            bwd_args = [ctx, *fwd_out.items]
1670        else:
1671            bwd_args = [ctx, fwd_out]
1672
1673        bwd_src = AttrSource(self.parent_source, member="backward")
1674        if isinstance(self.bwd_graph, types.FunctionType):
1675            bwd_fn = UserFunctionVariable(self.bwd_graph, source=bwd_src)
1676        elif isinstance(self.bwd_graph, types.MethodType):
1677            bwd_fn = UserMethodVariable(
1678                self.bwd_graph.__func__,
1679                UserDefinedClassVariable(self.bwd_graph.__class__),
1680                source=bwd_src,
1681            )
1682            bwd_args = [bwd_fn.obj, *bwd_args]
1683        else:
1684            unimplemented("autograd.Function backward is neither a function nor a method")
1685
1686        def is_strict_for(v: VariableTracker):
1687            if isinstance(v, variables.TensorVariable):
1688                # we can be more lax for values produced by the forward tracer
1689                return v.proxy.tracer is not fwd_tracer
1690            return True
1691
1692        with tx.output.subtracer(fwd_fn, fwd_tracer), tx.strict_translation_mode(
1693            is_strict_for
1694        ):
1695            (bwd_out, _), bwd_graph, bwd_freevars = speculate_subgraph(
1696                tx,
1697                bwd_fn,
1698                bwd_args,
1699                kwargs,
1700                "autograd.Function",
1701                enable_grad=False,
1702                set_subgraph_inputs="manual",
1703                restore_side_effects=False,
1704                tracer=bwd_tracer,
1705            )
1706
1707        # TODO: assert that bwd_graph didn't capture values that were
1708        # not created inside fwd_graph.
1709
1710        # TODO(oulgen): Ideally, we would not search for the output node like this,
1711        # but as things currently stand there could be nodes after the
1712        # output node.
1713        # This is bug-prone: if there is code after the output node, then
1714        # graph.output will append the new output at the very end of the graph,
1715        # which might be a behavior difference.
1716
1717        # Rewrite the output of fwd_graph to (output, stuff_necessary_for_bwd)
1718        for node in fwd_graph.find_nodes(op="output"):
1719            fwd_graph.erase_node(node)
1720            break
1721
1722        # Because we lift the bwd_freevars as inputs of the bwd_graph,
1723        # we have to manually add the bwd_freevars as outputs of fwd_graph.
1724        # However, the bwd_freevars returned from speculate_subgraph use the Proxies of the bwd_graph,
1725        # so we need to convert them to Proxies of the fwd_graph and then generate the new fwd_graph output.
1726        fwd_proxy_of_bwd_freevars = []
1727        for k in bwd_freevars.keys():
1728            if k in fwd_freevars:
1729                fwd_proxy_of_bwd_freevars.append(fwd_freevars[k])
1730            else:
1731                fwd_proxy_of_bwd_freevars.append(k)
1732
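        # Continuing the MySin example from the docstring above, the rewritten fwd
        # output becomes roughly (sin_out, [x]), i.e. (forward result, values the
        # bwd graph needs).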
1733        new_fwd_graph_outputs = (fwd_out.as_proxy(), fwd_proxy_of_bwd_freevars)
1734        new_fwd_graph_outputs = pytree.tree_map(lambda x: x.node, new_fwd_graph_outputs)
1735        fwd_graph.output(new_fwd_graph_outputs)
1736        fwd_graph.lint()
1737
1738        # Store fwd_body
1739        fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
1740        fwd_name = add_subgraph(
1741            tx,
1742            "fwd_body",
1743            torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph),
1744        )
1745
1746        fwd_node = make_attr(tx, fwd_name)
1747
1748        # The types of the original args can be arbitrary, but we only support basic types in the FX graph.
1749        # So the speculated subgraph's inputs include the original tensor args and the lifted freevars.
1750        # We need to filter out the original tensor args and concatenate them with the lifted freevars
1751        # to generate the proxy args for the FX call_function node.
1752        filtered_args = []
1753        # A boolean list marking whether the corresponding argument is a tensor.
1754        # This is used to determine whether an FX node's argument should be an argument of
1755        # ApplyTemplate.forward and whether we should skip the corresponding output of
1756        # ApplyTemplate.backward in torch._functorch.autograd_function.AutogradFunctionApply.
1757        args_tensor_mask = [False] * len(args)
1758        for i, arg in enumerate(args):
1759            if isinstance(arg, (variables.TensorVariable, variables.SymNodeVariable)):
1760                filtered_args.append(arg)
1761                args_tensor_mask[i] = True
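        # For example (illustrative): for args == (x: Tensor, n: int) this yields
        # filtered_args == [x] and args_tensor_mask == [True, False]; the bwd output
        # corresponding to n is dropped below.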
1762
1763        # Rewrite the output of bwd_graph to remove the grad output for the non-Tensor args.
1764        new_bwd_graph_outputs = None
1765        for node in bwd_graph.find_nodes(op="output"):
1766            bwd_graph.erase_node(node)
1767            break
1768
1769        # As with the fwd proxies above, we need to use the bwd proxies of the bwd_graph
1770        # if some of the outputs come from fwd_freevars.
1771        bwd_out_proxy = bwd_out.as_proxy()
1772        bwd_proxy_of_fwd_freevars = []
1773        if isinstance(bwd_out_proxy, (tuple, list)):
1774            for k in bwd_out_proxy:
1775                if k in bwd_freevars:
1776                    bwd_proxy_of_fwd_freevars.append(bwd_freevars[k])
1777                else:
1778                    bwd_proxy_of_fwd_freevars.append(k)
1779        else:
1780            if bwd_out_proxy in bwd_freevars:
1781                bwd_proxy_of_fwd_freevars = bwd_freevars[bwd_out_proxy]
1782            else:
1783                bwd_proxy_of_fwd_freevars = bwd_out_proxy
1784
1785        # Remove bwd output for non-Tensor args.
1786        output_proxy = bwd_proxy_of_fwd_freevars
1787        if isinstance(output_proxy, (tuple, list)):
1788            new_bwd_graph_outputs = ()
1789            for x, mask in zip(output_proxy, args_tensor_mask):
1790                if mask:
1791                    new_bwd_graph_outputs = new_bwd_graph_outputs + (x,)
1792                else:
1793                    assert x is None, f"Grad of non-Tensor arg {x} is not None."
1794        else:
1795            new_bwd_graph_outputs = output_proxy
1796
1797        # Update the bwd graph output.
1798        new_bwd_graph_outputs = pytree.tree_map(
1799            lambda x: None if x is None else x.node, new_bwd_graph_outputs
1800        )
1801        bwd_graph.output(new_bwd_graph_outputs)
1802        bwd_graph.lint()
1803
1804        # Store bwd_body
1805        bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
1806        bwd_name = add_subgraph(
1807            tx,
1808            "bwd_body",
1809            torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph),
1810        )
1811
1812        bwd_node = make_attr(tx, bwd_name)
1813
1814        tx.output.side_effects = prev_side_effects
1815
1816        p_args = (
1817            fwd_node,
1818            bwd_node,
1819            *([arg.as_proxy() for arg in filtered_args] + list(fwd_freevars.keys())),
1820        )
1821        example_value = pytree.tree_map_only(
1822            torch.fx.Proxy,
1823            lambda a: a.node.meta["example_value"],
1824            fwd_out.as_proxy(),
1825        )
1826
1827        # Store the invocation as a call
1828        from torch._functorch.autograd_function import autograd_function_apply
1829
1830        return wrap_fx_proxy(
1831            tx=tx,
1832            proxy=tx.output.create_proxy(
1833                "call_function",
1834                autograd_function_apply,
1835                args=p_args,
1836                kwargs={"args_tensor_mask": args_tensor_mask},
1837            ),
1838            example_value=example_value,
1839        )
1840