1# mypy: allow-untyped-defs
2"""
3This module is responsible for transforming functions to be traced into a form
4that is easier for the downstream infra (e.g. Autograd, FX, AOTAutograd analysis)
5to handle.
6
7It does so by:
81. functionalization (including RNG functionalization)
92. creating a joint graph when required
103. transforming mutations into extra outputs
114. dispatching subclasses
12"""
13
14import warnings
15from contextlib import contextmanager, nullcontext
16from functools import wraps
17from typing import Any, Callable, List, Tuple, Union
18from unittest.mock import patch
19
20import torch
21import torch.fx.traceback as fx_traceback
22import torch.utils._pytree as pytree
23from torch import Tensor
24from torch._decomp.decompositions_for_rng import PhiloxStateTracker
25from torch._guards import detect_fake_mode
26from torch._prims_common import CUDARngStateHelper
27from torch.fx.experimental.proxy_tensor import (
28    maybe_disable_thunkify,
29    maybe_enable_thunkify,
30)
31from torch.fx.experimental.symbolic_shapes import (
32    definitely_false,
33    PropagateUnbackedSymInts,
34    sym_eq,
35)
36from torch.nn.utils import stateless
37
38from .. import config
39from .collect_metadata_analysis import run_functionalized_fw_and_collect_metadata
40from .functional_utils import (
41    from_fun,
42    has_data_mutation,
43    has_metadata_mutation,
44    is_fun,
45    sync_functional_tensor,
46    to_fun,
47    was_inductor_storage_resized,
48)
49from .logging_utils import setup_stacktrace_preservation_hooks
50from .schemas import (
51    AOTConfig,
52    MutationType,
53    OutputType,
54    SubclassMeta,
55    SubclassTracingInfo,
56    ViewAndMutationMeta,
57)
58from .subclass_utils import (
59    create_subclass_meta,
60    remap_unwrapped_subclass_arg_indices,
61    requires_subclass_dispatch,
62    unwrap_tensor_subclasses,
63    wrap_tensor_subclasses_maybe_joint,
64)
65from .utils import maybe_to_fresh_input
66
67
68# This function returns a new function that returns mutated inputs as outputs.
69# If keep_data_input_mutations is set, then we assume that data-only mutations
70# will be left in the graph, and we only return metadata-mutated inputs as outputs.
71def fn_input_mutations_to_outputs(
72    fn: Callable,
73    meta: ViewAndMutationMeta,
74    keep_data_input_mutations: bool,
75) -> Any:
76    @wraps(fn)
77    def inner_fn(*args):
78        outs = fn(*args)
79        assert len(meta.output_info) == len(outs)
80        # The compiled fw will return mutated input tensors, *including* metadata-only mutation.
81        # However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs.
82        # (because data-only input mutations are handled directly in the compiled graph)
83        mutated_inputs_to_return = [
84            x for (i, x) in enumerate(args) if i in meta.mutated_inp_runtime_indices
85        ]
86        return *mutated_inputs_to_return, *outs
87
88    return inner_fn
89
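# Illustrative sketch (not part of this module; the toy names below are hypothetical):
# if meta.mutated_inp_runtime_indices == [0], then wrapping
#
#     def user_fn(a, b):
#         a.mul_(2)          # data mutation on input 0
#         return (a + b,)
#
# with fn_input_mutations_to_outputs(user_fn, meta, keep_data_input_mutations=False)
# produces a function whose outputs are roughly (a, a + b): the mutated input is
# prepended to the user outputs, so the mutation shows up as an extra graph output.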
90
91# This function takes in a fn with external aliasing and mutation,
92# and returns a new fn with no external aliasing and mutation,
93# as needed for autograd.
94# The main transformations are:
95# - Return mutated inputs as extra outputs
96# - Clone mutated inputs that require gradients,
97#   because autograd will require us to pass the pre-mutated inputs into autograd.grad
98# - Return intermediate bases of outputs as additional outputs,
99#   needed to appease autograd.Function
100# The new function returns:
101# (1) The updated outputs
102# (2) A boolean mask of len(new_fn_outputs),
103#     that can be used to tell autograd.grad which outputs should get tangents
104#     if we trace the backward.
105def fn_prepped_for_autograd(
106    fn: Callable,
107    meta: ViewAndMutationMeta,
108) -> Any:
109    @wraps(fn)
110    def inner_fn(*args):
111        args_maybe_cloned = [
112            maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args)
113        ]
114
115        outs = fn(*args_maybe_cloned)
116        assert isinstance(outs, (tuple, list))
117        outs = list(outs)
118        assert len(meta.output_info) == len(outs)
119
120        mutated_inputs_to_return = [
121            x
122            for (i, x) in enumerate(args_maybe_cloned)
123            if i in meta.mutated_inp_runtime_indices
124        ]
125
126        intermediate_bases = []
127        for i, (o, info) in enumerate(zip(outs, meta.output_info)):
128            if info.output_type == OutputType.alias_of_intermediate_save_as_output:
129                intermediate_bases.append(o._base)
130
131        assert meta.num_intermediate_bases == len(intermediate_bases)
132
133        # the compiled forward should return (mutated_inputs, user_outs, intermediate_bases)
134        fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases
135
136        # Also return a boolean mask specifying which outputs of this function will be used as tangents
137        mutated_inputs_grad_mask = [
138            meta.input_info[meta.mutated_inp_runtime_indices[i]].mutates_data
139            and meta.input_info[meta.mutated_inp_runtime_indices[i]].requires_grad
140            for (i, x) in enumerate(mutated_inputs_to_return)
141        ]
142
143        # Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw
144        # For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead,
145        # which we *should* send to grad()
146        output_grad_mask = [
147            meta.output_info[i].output_type
148            in [
149                OutputType.non_alias,
150                OutputType.unsafe_view_alias,
151                OutputType.custom_function_view,
152            ]
153            # Also, only tensor outputs should participate in the backward
154            # (in particular, Symint outputs in the forward graph shouldn't get tangents)
155            and issubclass(meta.output_info[i].raw_type, Tensor)
156            and meta.output_info[i].requires_grad
157            for (i, x) in enumerate(outs)
158        ]
159
160        intermediate_base_grad_mask = [True for _ in range(len(intermediate_bases))]
161
162        out_grad_mask = (
163            mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask
164        )
165        assert len(out_grad_mask) == len(fw_outs_to_return)
166
167        # Take care to grab and sync the updated inputs from args_maybe_cloned (the inputs we actually mutate!)
168        # and not primals (the preserved inputs, pre-mutation, that we pass to grad())
169        # This is annoying: our joint function needs to be aware of functionalization
170        # (syncing mutated inputs before calling autograd.grad())
171        # In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner.
172        for arg in args_maybe_cloned:
173            if not isinstance(arg, Tensor):
174                continue
175            sync_functional_tensor(arg)
176
177        return fw_outs_to_return, out_grad_mask
178
179    return inner_fn
180
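# Illustrative sketch (not part of this module; the toy names are hypothetical): for
#
#     def user_fn(a, b):      # `a` requires grad and is mutated; the output is not a view
#         a.mul_(2)
#         return (a * b,)
#
# with meta.mutated_inp_runtime_indices == [0], the prepped function returns roughly
#
#     fw_outs_to_return == (a_updated, a_updated * b)   # mutated input first, then user outputs
#     out_grad_mask     == [True, True]                 # both entries participate in the backward
#
# If the output were instead a view of an input (e.g. `return (a.view(-1),)`), its mask
# entry would be False, since aliases of inputs don't receive tangents.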
181
182# Given a fn, computes the joint.
183# NOTE: fn is expected to have the following behavior:
184# (1) fn() needs to return a tuple of (outs, mask),
185#     where `mask` tells us which outputs are meant to have tangents.
186#     we don't know this info automatically, because we don't actually want to blindly
187#     compute tangents for every output that requires grad.
188#     Specifically, outputs that alias inputs won't participate in the backward or get tangents.
189# (2) fn() cannot mutate any inputs that require gradient.
190#     Otherwise, when we compute autograd.grad(), we will not take those input mutations into account
191#     (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
192def create_joint(fn: Callable, *, aot_config: AOTConfig) -> Any:
193    def inner_fn(primals: List[Any], tangents: List[Any]):
194        outs, tangent_mask = fn(*primals)
195
196        assert len(tangent_mask) == len(outs)
197        outs_to_grad = [
198            o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent
199        ]
200        assert len(outs_to_grad) == len(tangents)
201
202        # Get the inputs that need gradients
203        grad_primals = []
204        inputs_needs_grads = []
205        # Note that we're using primals here (the pre-mutation inputs),
206        # being careful not to pass any mutated inputs into autograd.grad()
207        for p in primals:
208            is_grad_tensor = isinstance(p, Tensor) and p.requires_grad
209            inputs_needs_grads.append(is_grad_tensor)
210            if is_grad_tensor:
211                grad_primals.append(p)
212
213        # Get the outputs that need gradients
214        needed_outs = []
215        needed_tangents = []
216        for out, tangent in zip(outs_to_grad, tangents):
217            if isinstance(out, Tensor) and out.requires_grad:
218                # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
219                # The issue is that we are sensitive to decomps that don't accurately maintain
220                # their output's _base.shape compared to eager mode, and this helps mitigate a bit.
221                # The not definitely_false is also sketchy; if unbacked
222                # symints are involved, we're just going to assume that the
223                # decomps set up the base shape correctly
224                needed_outs.append(
225                    out
226                    if not definitely_false(sym_eq(out.shape, tangent.shape))
227                    else out.view(tangent.shape)
228                )
229                needed_tangents.append(tangent)
230
231        setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])
232
233        if config.functionalize_rng_ops:
234            PhiloxStateTracker.mark_beginning_of_backward()
235        backward_out: Tuple[Tensor, ...] = ()
236        # Call the backwards pass
237        if grad_primals:
238            functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
239                torch._C._TorchDispatchModeKey.FUNCTIONAL
240            )
241            if functional_tensor_mode is not None:
242                # Side-Effect Tokens:
243                # We want to have independent chains of tokens for forward and backward.
244                # functional_tensor_mode._tokens is used by both.
245                # We memoize the result tokens of forward in functional_tensor_mode._tokens_forward_output,
246                # to return them as joint graph outputs.
247                # We clean functional_tensor_mode._tokens before backward, to prevent reuse of forward tokens in backward.
248                # Joint graph tracing allows token discovery,
249                # so all the tokens used in the backward will be created and added as graph inputs during tracing.
250                functional_tensor_mode._tokens_forward_output = (
251                    functional_tensor_mode._tokens
252                )
253                functional_tensor_mode._tokens = {}
254
255            with set_partitioner_tag_is_backward(), fx_traceback.preserve_node_meta():
256                # for full graph export, we always export a joint graph where we assume no tangents are needed.
257                if aot_config.no_tangents:
258                    assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1
259                    backward_out = torch.autograd.grad(
260                        needed_outs,
261                        grad_primals,
262                        allow_unused=True,
263                    )
264                else:
265                    backward_out = torch.autograd.grad(
266                        needed_outs,
267                        grad_primals,
268                        grad_outputs=needed_tangents,
269                        allow_unused=True,
270                    )
271        backward_out_iter = iter(backward_out)
272        return outs, [
273            next(backward_out_iter) if i else None for i in inputs_needs_grads
274        ]
275
276    def inner_fn_with_anomaly(*args):
277        with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
278            warnings.filterwarnings("ignore", "Anomaly Detection has been enabled.")
279            with torch.autograd.detect_anomaly(check_nan=False):
280                return inner_fn(*args)
281
282    return inner_fn_with_anomaly
283
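# Illustrative sketch (hypothetical names): create_joint turns an fn returning
# (outs, tangent_mask) into a joint function
#
#     joint(primals: List[Any], tangents: List[Any]) -> (outs, grad_inputs)
#
# where grad_inputs is aligned with primals and contains None for primals that don't
# require grad. E.g. if fn(w, x) returns ((loss,), [True]) and only w requires grad,
# then joint([w, x], [grad_loss]) returns roughly ((loss,), [dloss_dw, None]).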
284
285def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True) -> Any:
286    # Functionalization of rng ops changes the calling convention of the joint graph.
287    # It goes from (primals, tangents) to (primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset)
288    # At runtime, we pass on the current seed and offset. This is hidden from
289    # the user.
290    fake_mode = detect_fake_mode()
291    if fake_mode is None:
292        fake_mode = nullcontext()
293
294    def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"):
295        out = PhiloxStateTracker.get_state_as_tensor()
296        return out
297
298    def override_set_rng_state(x, device: Union[int, str, torch.device] = "cuda"):
299        PhiloxStateTracker.set_state_from_tensor(x)
300
301    def append_rng_offsets(args):
302        if trace_joint:
303            # args signature before: Tuple(fwd_outputs), Tuple(bwd_outputs)
304            # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_outputs, new_bwd_rng_offset)
305            return (
306                (*args[0], PhiloxStateTracker.get_updated_fwd_offset()),
307                (*args[1], PhiloxStateTracker.get_updated_bwd_offset()),
308            )
309        else:
310            # args signature before: Tuple(fwd_outputs)
311            # args signature after: Tuple(fwd_outputs, new_fwd_rng_offset)
312            return (*args, PhiloxStateTracker.get_updated_fwd_offset())
313
314    def traced_joint(
315        primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset
316    ):
317        with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
318            "torch.cuda.set_rng_state", override_set_rng_state
319        ):
320            return append_rng_offsets(func(primals, tangents))
321
322    def traced_forward(*primals_fwd_seed_fwd_base_offset):
323        # The signature is (*primals, seed, offset)
324        with patch("torch.cuda.get_rng_state", override_get_rng_state), patch(
325            "torch.cuda.set_rng_state", override_set_rng_state
326        ):
327            return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2]))
328
329    if trace_joint:
330        # Get the current seed and offset to setup tracing.
331        fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
332            fake_mode
333        )
334        bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
335            fake_mode
336        )
337        PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
338        PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward")
339        return traced_joint, (
340            *args,
341            fwd_seed,
342            fwd_base_offset,
343            bwd_seed,
344            bwd_base_offset,
345        )
346    else:
347        # Get the current seed and offset to setup tracing.
348        fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(
349            fake_mode
350        )
351        PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
352        return traced_forward, (*args, fwd_seed, fwd_base_offset)
353
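# Illustrative sketch of the calling-convention change (hypothetical names, shapes approximate):
#
#     trace_joint=True:
#         (primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset)
#             -> ((*fwd_outs, new_fwd_rng_offset), (*bwd_outs, new_bwd_rng_offset))
#     trace_joint=False:
#         (*primals, fwd_seed, fwd_base_offset)
#             -> (*fwd_outs, new_fwd_rng_offset)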
354
355@contextmanager
356def set_partitioner_tag(tag: str):
357    meta_key = "partitioner_tag"
358    assert fx_traceback.has_preserved_node_meta()
359
360    original_val = fx_traceback.current_meta.get(meta_key, None)
361    fx_traceback.current_meta[meta_key] = tag
362    try:
363        yield
364    finally:
365        fx_traceback.current_meta[meta_key] = original_val
366
367
368def set_partitioner_tag_is_backward():
369    return set_partitioner_tag("is_backward")
370
371
372def set_partitioner_tag_must_be_in_backward():
373    return set_partitioner_tag("must_be_in_backward")
374
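# Illustrative sketch (hypothetical usage): ops traced while one of these context managers
# is active get node.meta["partitioner_tag"] stamped on them, which the partitioner later
# uses to decide which graph (forward or backward) a node must live in:
#
#     with fx_traceback.preserve_node_meta():
#         with set_partitioner_tag_is_backward():
#             # nodes traced here are tagged with partitioner_tag == "is_backward"
#             grads = torch.autograd.grad(outs, inputs, grad_outputs=tangents)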
375
376# This creates the final function that we want to trace using make_fx(),
377# in both aot_dispatch_autograd and aot_dispatch_base.
378# Preconditions:
379# - fn corresponds to the user's fw function
380# - fn arguments have been flattened, duplicate arguments have been handled
381# - In the returned function, the "primals" argument *includes* synthetic bases.
382# This function does the work of functionalizing the input function,
383# and performing copy_() calls at the end of the function if `keep_input_mutations` is set.
384# The function returned has signature that is either:
385# (1) "traced_fn(primals: List[Any])" if trace_joint is False
386# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True
387# Returns a new (functionalized) function, and updated arguments to call it with.
388def create_functionalized_fn(
389    fn,
390    args,
391    *,
392    meta: ViewAndMutationMeta,
393    aot_config: AOTConfig,
394    trace_joint: bool,
395) -> Any:
396    @wraps(fn)
397    def _functionalized_f_helper(*args):
398        with maybe_enable_thunkify():
399            # See Note [Disabling Functionalize TLS Above Python Functionalization]
400            disable_above = torch._C._ExcludeDispatchKeyGuard(
401                torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
402            )
403
404            with disable_above:
405                # The functionalization code here can potentially trigger traces
406                # into the graph, but we'd prefer to NOT do this, because if we
407                # trace them now, we will end up with FX nodes that don't have
408                # module stack annotations, which makes unflattener unhappy.
409                # Wrap inputs into functional wrappers
410                f_args = pytree.tree_map(to_fun, args)
411
412                # Run the joint
413                f_outs = fn(*f_args)
414
415            if trace_joint:
416                # We support a limited amount of mutation of graph inputs during the backward pass.
417                # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
418                # Here, we perform extra checks for primals that were mutated in the **backward**
419                # We're doing the checks here instead of doing them with the rest of the input mutation handling because:
420                # - We need to detect inputs that were mutated in the backward **separately** from mutations that happened
421                #   during the forward, because the handling is different: some input mutations from the forward
422                #   can be only handled in a fw-only runtime epilogue, and in theory if we wanted to handle those same
423                #   types of mutations in the backward we would need a bw-only runtime epilogue.
424                # - We could in theory have our analysis pass differentiate mutations in the fw from mutations in
425                #   the bw by running our analysis first on the fw-only graph, and then on the joint graph. This would
426                #   require an extra round of tracing though, so it's more efficient to do in-line here.
427                assert (
428                    isinstance(args, tuple)
429                    and len(args) == 2
430                    and isinstance(args[0], (list, tuple))
431                )
432                # Only look at mutations that happened to forward inputs (e.g. fw buffers that were saved for bw)
433                primals_before = args[0]
434                primals_after = pytree.tree_map(from_fun, f_args[0])
435                for idx, (f_inpt, before, after, inpt_info) in enumerate(
436                    zip(f_args[0], primals_before, primals_after, meta.input_info)
437                ):
438                    # Store information about mutations in the joint (for backward analysis)
439                    joint_mutates_data = has_data_mutation(f_inpt)
440
441                    joint_mutates_metadata = has_metadata_mutation(
442                        f_inpt, before, check_only_storage_mutation=False
443                    )
444
445                    # Ban metadata mutations on fw inputs during the bw
446                    if not inpt_info.mutates_metadata:
447                        assert (
448                            not joint_mutates_metadata
449                        ), "Found a graph input that had its metadata mutated in the backward. This is not supported"
450
451                    # Ban storage resizing on fw inputs during the bw
452                    if not inpt_info.mutation_inductor_storage_resize:
453                        assert not was_inductor_storage_resized(
454                            f_inpt
455                        ), "Found a graph input that had storage resizing in the backward. This is not supported"
456
457                    # Allow data mutations on fw inputs during the bw, but only if they do not require grad,
458                    # so we can guarantee that we can keep the mutations in the graph.
459                    if (
460                        joint_mutates_data
461                        and not inpt_info.mutates_data
462                        and not inpt_info.mutates_storage_metadata
463                    ):
464                        # We don't ban mutations of inputs with inpt_info.requires_grad here -
465                        # we'll check at runtime and fail only when the backward runs under torch.is_grad_enabled (create_graph)
466                        # Add node meta for copy_ for partitioner that this node should be in backward graph.
467                        with torch.fx.traceback.preserve_node_meta(), set_partitioner_tag_must_be_in_backward():
468                            before.copy_(after)
469                        meta.indices_of_inputs_that_requires_grad_with_mutations_in_bw.append(
470                            idx
471                        )
472                # Now that we covered mutations to *forward* inputs during the backward,
473                # we also need to cover mutations to *backward-only* inputs during the backward (e.g. mutation to a grad_out).
474                # Today, we will just error in all cases of this happening unless someone needs us to support it.
475                tangents_before = args[1]
476                tangents_after = pytree.tree_map(from_fun, f_args[1])
477                for f_inpt, before, after in zip(
478                    f_args[1], tangents_before, tangents_after
479                ):
480                    assert not has_metadata_mutation(
481                        f_inpt, before, check_only_storage_mutation=False
482                    ) and not has_data_mutation(
483                        f_inpt
484                    ), "Found an input to the backward that was mutated during the backward pass. This is not supported"
485
486            if aot_config.keep_inference_input_mutations:
487                # Note: This is a bit annoying. There's a layering issue here, where:
488                # (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
489                # (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs.
490                #     However, we **only** want to support this for inputs that have data-only (and no metadata) mutations,
491                #     because inductor (and backends in general) would prefer not to see these (e.g. as_strided_(), resize_()).
492                #     This makes it pretty difficult for this logic to operate on synthetic bases.
493                # (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
494                #     (unpacked) input aliases, instead of the synthetic base.
495                # Example case where (3) could be important:
496                #
497                #     def f(x, y):
498                #         x.mul_(2)
499                #         y.mul_(3)
500                #         return x, y
501                #    a = torch.ones(1_000_000)
502                #    x, y = f(a[0:9], a[1:10])
503                #
504                # It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
505                # a giant "updated synthetic base" and copying into a's entire storage.
506                #
507                # For now, we are pessimistically not performing the optimization from (3);
508                # we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
509                # This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
510                # about synthetic bases.
511                for i, (inpt_old, inpt_f) in enumerate(
512                    zip(args, f_args) if not trace_joint else zip(args[0], f_args[0])
513                ):
514                    if not isinstance(inpt_f, torch.Tensor):
515                        continue
516                    assert is_fun(inpt_f)
517                    inpt_new = from_fun(inpt_f)
518                    if (
519                        meta.input_info[i].mutation_type
520                        == MutationType.MUTATED_IN_GRAPH
521                    ):
522                        # See Note [set_() Input Mutations in AOTAutograd]
523                        # all mutations on the input must be under no_grad, so it is safe to put in the graph
524                        # Here, we're saying that if an input experienced a set call, inp.set_(other),
525                        # then we can effectively not have to worry about whether its data was mutated.
526                        # There are 3 cases:
527                        # (1) We mutate inp *after* the set_() call. other is a graph intermediate.
528                        #     In this case, we're not really mutating the input storage of "inp";
529                        #     we're mutating the storage of an intermediate value (other),
530                        #     and slamming that storage into the input tensor. So no data mutation is necessary.
531                        # (2) We mutate inp *after* the set_() call. other is a graph *input*.
532                        #     In this case, the data mutation will be properly handled in the runtime
533                        #     epilogue during the processing of "other"
534                        # (3) We mutate inp *before* the set_() call.
535                        #     This case is *not* currently handled.
536                        if meta.input_info[i].mutates_storage_metadata:
537                            with torch.no_grad():
538                                inpt_old.set_(inpt_new)
539
540                        # Note [Ordering of resize_() and set_()]
541                        # Importantly: the common usage in FSDP is that we have a dummy parameter
542                        # that sees a set_() and **then** a resize_().
543                        # We must put those mutations into the graph in the same order,
544                        # since running them in the opposite order will have different behavior.
545                        # We fully ban resize_() followed by set_() for now, although in principle
546                        # we could support this.
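                        # Illustrative sketch of the supported FSDP-style ordering described
                        # above (hypothetical names; the reverse order is rejected):
                        #
                        #     param.set_(unsharded_param)            # set_() first ...
                        #     ...
                        #     param.untyped_storage().resize_(0)     # ... then the storage resize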
547                        if meta.input_info[i].mutation_inductor_storage_resize:
548                            # resizing is not supported on subclasses (we error earlier if this happens)
549                            from torch._subclasses.functional_tensor import (
550                                FunctionalTensor,
551                            )
552
553                            assert isinstance(inpt_f, FunctionalTensor)
554                            old_storage_size = torch._functionalize_get_storage_size(  # type: ignore[attr-defined]
555                                inpt_f.elem, before=True
556                            )
557                            new_storage_size = torch._functionalize_get_storage_size(  # type: ignore[attr-defined]
558                                inpt_f.elem, before=False
559                            )
560                            if old_storage_size != new_storage_size:
561                                assert (
562                                    old_storage_size == 0 or new_storage_size == 0
563                                ), f"""\
564    Encountered a storage resize during tracing on input {i}. Old nbytes={old_storage_size}, new nbytes={new_storage_size}
565    We only support storage resizing on graph inputs as long as the input either starts or ends with a storage size of 0
566    (the case for FSDP)"""
567                                torch.ops.inductor.resize_storage_bytes_(
568                                    inpt_old, new_storage_size
569                                )
570                            if new_storage_size == 0:
571                                # Even if we marked the input as having a data mutation (thus needing a copy_()),
572                                # We should **ignore** it if our input has no storage
573                                # (this can happen if, e.g. we temporarily resize our input, copy data into it,
574                                #  and resize it back down to zero)
575                                continue
576                        # Optimization: if the copy_() is a no-op then don't include it in the graph.
577                        # In theory inductor could optimize this away; however, in FSDP we end up with
578                        # param.copy_(param), where param is a zero-storage-size tensor,
579                        # and running this op in eager mode (using the aot_eager backend) will result in a segfault.
580                        # So we may as well optimize it away here.
581                        if inpt_old is inpt_new:
582                            # (This check needs to be done after putting resize_() in the graph,
583                            # since a resize_(0) doesn't actually change the FunctionalTensor's inner tensor)
584                            continue
585                        # We found an input that had a (data-only) mutation.
586                        # Since keep_input_mutations is set, we need to faithfully apply a copy_()
587                        # so the compiler will see the input mutation in the graph.
588                        if (
589                            meta.input_info[i].mutates_data
590                            and meta.input_info[i].mutations_hidden_from_autograd
591                        ):
592                            # Hidden from autograd = run under no_grad, **and** don't bump VC
593                            # (although if the tensor was created in inference mode, it has no VC)
594                            if inpt_old.is_inference():
595                                maybe_preserve_vc = nullcontext()
596                            else:
597                                maybe_preserve_vc = torch.autograd._unsafe_preserve_version_counter(
598                                    inpt_old  # type: ignore[assignment]
599                                )
600                            with torch.no_grad(), maybe_preserve_vc:
601                                inpt_old.copy_(inpt_new)
602                        elif (
603                            meta.input_info[i].mutates_data
604                            and meta.input_info[
605                                i
606                            ].mutations_under_no_grad_or_inference_mode
607                        ):
608                            # Under no_grad = run under no_grad (we still bump the VC though)
609                            # (inference_mode will also bump the VC, as long as the tensor in question
610                            # was created outside of inference_mode)
611                            with torch.no_grad():
612                                inpt_old.copy_(inpt_new)
613                        elif meta.input_info[i].mutates_data:
614                            inpt_old.copy_(inpt_new)
615
616                # When an output tensor is a functionalized mutated input, and we
617                # were able to move the mutation into the graph, then we can return
618                # the mutated input directly. This prevents duplicating the
619                # tensor's contents.
620                flat_outs, outs_spec = pytree.tree_flatten(f_outs)
621                flat_outs = [from_fun(o) for o in flat_outs]
622                num_outs = len(meta.output_info)
623
624                for i, outp in enumerate(flat_outs[:num_outs]):
625                    info = meta.output_info[i]
626                    if info.output_type != OutputType.is_input:
627                        continue
628
629                    assert info.base_idx is not None
630                    if (
631                        meta.input_info[info.base_idx].mutation_type
632                        == MutationType.MUTATED_IN_GRAPH
633                    ):
634                        fw_args = args[0] if trace_joint else args
635                        flat_outs[i] = fw_args[info.base_idx]
636                return pytree.tree_unflatten(flat_outs, outs_spec)
637
638            return pytree.tree_map(from_fun, f_outs)
639
640    # Kinda annoying, but needed to make sure that the fx graph we trace out has "primals"
641    # and "tangents" as its input names (which are special-cased by the partitioner)
642    # TODO (tmanlaibaatar) revisit this if we ever need to turn on non-strict joint graph export
643    def joint_helper(primals, tangents):
644        return _functionalized_f_helper(primals, tangents)
645
646    helper = joint_helper if trace_joint else _functionalized_f_helper
647    if config.functionalize_rng_ops:
648        # Setup the wrapper for functionalization of rng ops
649        helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint)
650
651    return helper, args
652
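# Illustrative sketch (hypothetical names): the returned (helper, args) pair is what
# ultimately gets traced with make_fx, e.g.
#
#     fn_to_trace, updated_args = create_functionalized_fn(
#         joint_fn, flat_args, meta=fw_metadata, aot_config=aot_config, trace_joint=True
#     )
#
# With trace_joint=True the traced signature is fn_to_trace(primals, tangents); with
# trace_joint=False it is the forward-only form described above. If
# config.functionalize_rng_ops is set, updated_args additionally carries the RNG
# seed/offset arguments added by create_functionalized_rng_ops_wrapper.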
653
654def handle_effect_tokens_fn(
655    fn,
656    args,
657    *,
658    meta: ViewAndMutationMeta,
659    trace_joint: bool,
660) -> Any:
661    num_tokens = len(meta.tokens)
662
663    @wraps(fn)
664    def inner_fn(*args):
665        # See Note [Disabling Functionalize TLS Above Python Functionalization]
666        disable_above = torch._C._ExcludeDispatchKeyGuard(
667            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
668        )
669
670        with disable_above:
671            # See Note [Side-Effectful Tokens in AOTAutograd]
672            if trace_joint:
673                assert isinstance(args, tuple) and isinstance(args[0], (list, tuple))
674                tokens = args[0][:num_tokens]
675                assert all(token.numel() == 0 for token in tokens)
676                args = (args[0][num_tokens:], *args[1:])
677            else:
678                tokens = args[:num_tokens]
679                assert all(token.numel() == 0 for token in tokens)
680                args = args[num_tokens:]
681
682            # Populate the current FunctionalTensorMode with the tokens per
683            # operator. See Note [FunctionalTensorMode is Stateful]
684            functional_tensor_mode = torch.utils._python_dispatch._detect_infra_mode(
685                torch._C._TorchDispatchModeKey.FUNCTIONAL
686            )
687            assert functional_tensor_mode is not None
688            f_tokens = pytree.tree_map(to_fun, tokens)
689            for i, k in enumerate(meta.tokens.keys()):
690                functional_tensor_mode._tokens[k] = f_tokens[i]
691
692            # Run the joint
693            outs = fn(*args)
694
695        # Return both the tokens and the outputs
696        # See Note [Side-Effectful Tokens in AOTAutograd]
697        if trace_joint:
698            assert len(outs) == 2
699            assert len(functional_tensor_mode._tokens_forward_output) == num_tokens
700            fwd_out_tokens = functional_tensor_mode._tokens_forward_output.values()
701
702            bwd_out_tokens = functional_tensor_mode._tokens.values()
703
704            f_fwd_out_tokens = [from_fun(t) for t in fwd_out_tokens]
705            f_bwd_out_tokens = [from_fun(t) for t in bwd_out_tokens]
706
707            meta.num_backward_tokens = len(bwd_out_tokens)
708            return ((*f_fwd_out_tokens, *outs[0]), (*outs[1], *f_bwd_out_tokens))
709
710        out_tokens = [from_fun(t) for t in functional_tensor_mode._tokens.values()]
711        return (*out_tokens, *outs)
712
713    # Additionally pass in tokens as inputs
714    # See Note [Side-Effectful Tokens in AOTAutograd]
715    additional_fwd_token_inputs = [torch.tensor([])] * num_tokens
716
717    if trace_joint:
718        args = ([*additional_fwd_token_inputs, *args[0]], *args[1:])
719    else:
720        args = [*additional_fwd_token_inputs, *args]
721    return inner_fn, args
722
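# Illustrative sketch (hypothetical names): with num_tokens == 2 and trace_joint == False,
# handle_effect_tokens_fn turns
#
#     fn(*primals) -> (out0, out1)
#
# into an inner_fn plus updated args such that
#
#     inner_fn(tok0, tok1, *primals) -> (out_tok0, out_tok1, out0, out1)
#
# i.e. empty token tensors are prepended to the inputs, and the resulting tokens are
# prepended to the outputs, which keeps side-effectful ops ordered in the traced graph.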
723
724# Given a function operating on Subclass -> Subclass, returns a function that operates on Tensor -> Tensor
725# Also returns:
726# - the new set of arguments to pass into this function (now that tensor subclasses have been eliminated)
727# - the updated ViewAndMutationMeta for this dense -> dense function.
728# The other important arguments are:
729# - flat_fn_maybe_joint: when is_joint_structure=True, this is the joint fw-bw function.
730#                        when is_joint_structure=False, this is just the forward function.
731# - fw_only: this is *always* the forward-only function.
732#   Why do we need this? We need to collect updated ViewAndMutationMeta on our new dense -> dense functions.
733#   In particular, we need this to tell the partitioner how many dense forward outputs there are.
734def aot_dispatch_subclass(
735    flat_fn_maybe_joint,
736    args: List[Any],
737    *,
738    is_joint_structure: bool,
739    meta: ViewAndMutationMeta,
740    fw_only: Callable,
741) -> SubclassTracingInfo:
742    # Skip logic if we don't need to trace through any subclasses
743    req_subclass_dispatch = requires_subclass_dispatch(args, meta)
744    if not req_subclass_dispatch:
745        return SubclassTracingInfo(
746            plain_tensor_trace_fn=flat_fn_maybe_joint,
747            plain_tensor_args=args,
748            maybe_subclass_meta=None,
749        )
750
751    # TODO: add subclass guards (later PR).
752
753    # What's going on here? We need to compute subclass metadata about the outputs of the joint (grad_inputs).
754    # Annoying: we don't know the grad input metas until we're in the middle of tracing the joint,
755    # so we set it later, while we're tracing the joint (see inner_fn() below).
756    # Another option would be to run our run_functionalized_fw_and_collect_metadata() function
757    # directly on the joint, but this would hurt compile time (adding yet another pass through the joint).
758    subclass_meta = SubclassMeta()
759
760    def inner_fn(fn, args, *, use_trace_joint: bool):
761        # Step 1: wrap tensor inputs into subclasses if necessary
762        all_args = wrap_tensor_subclasses_maybe_joint(
763            args, is_joint_structure=use_trace_joint, meta=meta
764        )
765
766        # Step 2: call the inner function, with our (maybe subclass) inputs
767        wrapped_outs = fn(*all_args)
768
769        if use_trace_joint:
770            # See Note: [Computing Subclass Metadata about grad_inputs]
771            # We also stash subclass info on our grad_inputs, if we're tracing the joint.
772            nonlocal subclass_meta
773            assert isinstance(wrapped_outs, tuple) and len(wrapped_outs) == 2
774            # Don't need fw outs since we already have subclass metadata on them
775            grad_inputs = wrapped_outs[1]
776            subclass_meta.grad_input_metas = create_subclass_meta(grad_inputs)
777
778        # Step 3: Unwrap any subclass outputs back into dense tensors
779        unwrapped_outs = unwrap_tensor_subclasses(
780            wrapped_outs, is_joint_structure=use_trace_joint
781        )
782        return unwrapped_outs
783
784    def joint_fn(primals, tangents):
785        with maybe_enable_thunkify():
786            return inner_fn(
787                flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True
788            )
789
790    def fw_fn(*primals):
791        with maybe_enable_thunkify():
792            return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)
793
794    def metadata_fn(*primals):
795        return inner_fn(fw_only, primals, use_trace_joint=False)
796
797    args_unwrapped = unwrap_tensor_subclasses(
798        args, is_joint_structure=is_joint_structure
799    )
800    remapped_static_indices = remap_unwrapped_subclass_arg_indices(
801        args, meta.static_input_indices
802    )
803
804    if is_joint_structure:
805        primals_unwrapped = args_unwrapped[0]
806        fn_to_trace = joint_fn
807    else:
808        primals_unwrapped = args_unwrapped
809        fn_to_trace = fw_fn
810
811    # Note: [Partitioner handling for Subclasses, Part 1]
812    # The way the partitioner works is that:
813    # (1) we pass in a single graph containing the joint fw/bw,
814    #     where the # of graph outputs corresponds to # fw_outputs + # grad_inputs
815    # (2) The partitioner accepts an argument, num_fwd_outputs,
816    #     and assumes that the first "num_fwd_outputs" graph outputs correspond
817    #     to outputs of the forward graph.
818    # How do tensor subclasses enter the picture?
819    # the num_fwd_outputs in the final graph is actually non-trivial to compute,
820    # because it can be influenced by input mutations and intermediate bases.
821    # So we compute it by inspecting the current ViewAndMutationMeta object.
822    # However, the original ViewAndMutationMeta that we computed was created
823    # on the subclass -> subclass graph,
824    # which can have a different number of outputs than the dense -> dense graph.
825    # That's why we create a fresh metadata object on the dense -> dense function here,
826    # and plumb it back up to the partitioner.
827    # See Note: [Partitioner handling for Subclasses, Part 2] for more info.
828    meta_updated = run_functionalized_fw_and_collect_metadata(
829        metadata_fn,
830        static_input_indices=remapped_static_indices,
831        keep_input_mutations=meta.keep_input_mutations,
832        is_train=meta.is_train,
833    )(*primals_unwrapped)
834
835    subclass_meta.fw_metadata = meta_updated
836
837    return SubclassTracingInfo(
838        plain_tensor_trace_fn=fn_to_trace,
839        plain_tensor_args=args_unwrapped,
840        maybe_subclass_meta=subclass_meta,
841    )
842
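# Illustrative sketch (hypothetical subclass): if the user forward takes a tensor subclass
# that wraps two dense tensors,
#
#     fw(subclass_arg) -> subclass_out
#
# then aot_dispatch_subclass returns a SubclassTracingInfo whose plain_tensor_trace_fn and
# plain_tensor_args operate purely on the unwrapped dense tensors, roughly
#
#     plain_tensor_trace_fn(inner_a, inner_b) -> (dense_out_a, dense_out_b)
#
# while maybe_subclass_meta records how to re-wrap those dense outputs (and, for the joint
# case, the grad_inputs) back into subclasses at runtime.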
843
844def create_functional_call(mod, params_spec, params_len, store_orig_mod=False):
845    # Redundant with dynamo, but worth having in case this gets invoked elsewhere.
846    # https://github.com/pytorch/pytorch/issues/103569
847
848    def functional_call(*args, **kwargs):
849        with stateless._reparametrize_module(
850            mod, pytree.tree_unflatten(args[:params_len], params_spec)
851        ), maybe_disable_thunkify():
852            if isinstance(mod, torch.fx.GraphModule):
853                with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
854                    warnings.filterwarnings(
855                        "ignore", "Anomaly Detection has been enabled."
856                    )
857                    with torch.autograd.detect_anomaly(check_nan=False):
858                        detect_fake_mode().epoch += 1
859                        out = PropagateUnbackedSymInts(mod).run(
860                            *args[params_len:], **kwargs
861                        )
862            else:
863                out = mod(*args[params_len:], **kwargs)
864
865        if not isinstance(out, (tuple, list)):
866            raise RuntimeError(
867                "Graph output must be a tuple(). This is so that we can avoid "
868                "pytree processing of the outputs. Please change the module to "
869                "have tuple outputs or use aot_module instead."
870            )
871        return out
872
873    # Note [Preserving the nn module stack metadata during export non-strict mode]
874    # This path is currently only used by the non-strict export flow,
875    # where we cannot rely on dynamo to preserve nn stack metadata in our captured graph.
876    # Instead, we stash the original user nn module here, and rely on `make_fx` to grab
877    # this stashed module and use it to track nn module stack metadata
878    if store_orig_mod and not hasattr(functional_call, "_orig_mod"):
879        functional_call._orig_mod = mod  # type: ignore[attr-defined]
880
881    return functional_call
882
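# Illustrative sketch (hypothetical module and inputs): create_functional_call lifts a
# module's parameters and buffers into explicit leading arguments,
#
#     named_state = {**dict(mod.named_parameters()), **dict(mod.named_buffers())}
#     flat_params, params_spec = pytree.tree_flatten(named_state)
#     functional_call = create_functional_call(mod, params_spec, len(flat_params))
#     out = functional_call(*flat_params, *example_inputs)   # mod must return a tuple/list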