xref: /aosp_15_r20/external/pytorch/torch/_higher_order_ops/auto_functionalize.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator, OperatorBase, OpOverload
from torch._prims_common import clone_preserve_strides
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


def get_base(tensor):
    if torch.is_inference_mode_enabled():
        return tensor._inference_mode_base
    else:
        return tensor._base


@dataclass
class ViewInfo:
    base_index: int
    size: Optional[Sequence[Union[int, torch.SymInt]]] = None
    stride: Optional[Sequence[Union[int, torch.SymInt]]] = None
    storage_offset: Optional[int] = None
    # When is_view is false, the tensor is the base, and
    # size, stride and storage_offset are all None.
    is_view: bool = True

    def regenerate_view(self, bases_list: List[Tensor]):
        if not self.is_view:
            return bases_list[self.base_index]

        assert self.stride is not None
        assert self.size is not None
        assert self.storage_offset is not None

        return torch.as_strided(
            bases_list[self.base_index],
            self.size,
            self.stride,
            self.storage_offset,
        )

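# Illustrative sketch (not part of this module's API, shown as comments only):
# how a ViewInfo is meant to be used. For a view such as `base[2:]` of a
# contiguous (4, 4) base, the view metadata can be replayed onto a fresh base
# with `regenerate_view`. The concrete sizes/strides below are assumptions
# chosen for the example.
#
#   base = torch.zeros(4, 4)
#   info = ViewInfo(base_index=0, size=(2, 4), stride=(4, 1), storage_offset=8)
#   x = info.regenerate_view([base])  # equivalent to base[2:]
#   whole = ViewInfo(base_index=0, is_view=False).regenerate_view([base])  # the base itself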

def write_view_information_to_args(
    mutable_arg_names: List[str],
    mutable_arg_types: List[torch.Type],
    kwargs: Dict[str, Any],
    arg_to_base_index: Dict[str, Any],
):
    """
    This function writes the view information into kwargs. It reads the mutable args from kwargs
    and uses arg_to_base_index and tensor information to write ViewInfo into kwargs.
    mutable_arg_names: mutable custom operator arg names.
    mutable_arg_types: mutable custom operator arg types.
    kwargs: the original custom operator args.
    arg_to_base_index: maps mutable_arg_name to an int | [int] that refers to the base tensor that
                       corresponds to the input tensor.
    """

    def write_single_view(prefix: str, tensor: Tensor, base_index: int):
        assert f"{prefix}_base_index" not in kwargs
        assert f"{prefix}_size" not in kwargs
        assert f"{prefix}_stride" not in kwargs
        assert f"{prefix}_storage_offset" not in kwargs

        if tensor is None:
            kwargs[f"{prefix}_base_index"] = None
        elif get_base(tensor) is None:
            # If the tensor is the base (not a view), for simplicity we do not serialize view metadata.
            kwargs[f"{prefix}_base_index"] = base_index
        else:
            kwargs[f"{prefix}_base_index"] = base_index
            kwargs[f"{prefix}_size"] = tensor.size()
            kwargs[f"{prefix}_stride"] = tensor.stride()
            kwargs[f"{prefix}_storage_offset"] = tensor.storage_offset()

    for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
        arg = kwargs[arg_name]
        if isinstance(arg_type, torch.ListType):
            if arg is None:
                kwargs[f"_{arg_name}_length"] = None
                continue

            kwargs[f"_{arg_name}_length"] = len(arg)
            for i, elem in enumerate(arg):
                write_single_view(
                    f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i]
                )

        elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)):
            write_single_view(
                f"_{arg_name}",
                kwargs[arg_name],
                arg_to_base_index.get(arg_name, None),
            )
        else:
            raise RuntimeError(f"Unsupported type {arg_type}")

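# Illustrative sketch (hypothetical values, shown as comments only): for a
# custom op with a single mutable tensor arg "x" that is a view whose base
# sits at index 0 of the caller's base list, the call below would add the
# flattened view-metadata keys into `kwargs`:
#
#   kwargs = {"x": x_view}
#   write_view_information_to_args(["x"], [torch.TensorType.get()], kwargs, {"x": 0})
#   # kwargs now also contains:
#   #   "_x_base_index": 0,
#   #   "_x_size": x_view.size(),
#   #   "_x_stride": x_view.stride(),
#   #   "_x_storage_offset": x_view.storage_offset()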

# Returns a dict of arg_name -> ViewInfo | [ViewInfo]
def read_view_information_from_args(
    mutable_arg_names: List[str],
    mutable_arg_types: List[torch.Type],
    kwargs: Dict[str, Any],
    all_bases: List[Tensor],
):
    """
    This reads the view information added by `write_view_information_to_args` from kwargs, pops it,
    and returns a dict arg_name -> ViewInfo | [ViewInfo] (if the input is a list) that maps each
    mutable arg to its view information.
    mutable_arg_names: mutable custom operator arg names.
    mutable_arg_types: mutable custom operator arg types.
    kwargs: args of auto_functionalize(custom_op, kwargs)
    """

    def get_arg(name):
        return kwargs.pop(name)

    def read_single_view(prefix):
        base_index = get_arg(f"{prefix}_base_index")
        if base_index is None:
            return None
        elif f"{prefix}_size" not in kwargs:
            assert f"{prefix}_stride" not in kwargs
            assert f"{prefix}_storage_offset" not in kwargs

            # This means that the argument is the base tensor itself
            # (it is recovered from all_bases by index at regeneration time).
            return ViewInfo(base_index, is_view=False)

        else:
            size = get_arg(f"{prefix}_size")
            stride = get_arg(f"{prefix}_stride")
            storage_offset = get_arg(f"{prefix}_storage_offset")
            return ViewInfo(base_index, size, stride, storage_offset, is_view=True)

    args_view_info: Dict[str, Any] = {}
    for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
        if isinstance(arg_type, torch.ListType):
            length = get_arg(f"_{arg_name}_length")
            if length is None:
                # The whole list is None.
                args_view_info[arg_name] = None
            else:
                args_view_info[arg_name] = [
                    read_single_view(f"_{arg_name}_{i}") for i in range(length)
                ]

        elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)):
            args_view_info[arg_name] = read_single_view(f"_{arg_name}")
        else:
            raise RuntimeError(f"Unsupported type {arg_type}")
    return args_view_info

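# Illustrative sketch (hypothetical, shown as comments only): the metadata
# written by `write_view_information_to_args` can be read back, after which
# the flattened keys are popped from kwargs again:
#
#   view_info = read_view_information_from_args(
#       ["x"], [torch.TensorType.get()], kwargs, all_bases=[base]
#   )["x"]
#   regenerated = view_info.regenerate_view([base])  # same view of `base` as the original "x"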

# NOTE: [auto-functionalizing custom ops]
# Users may wish to torch.compile custom ops that mutate their inputs.
# torch.compile will automatically support this op without anyone needing
# to provide a functionalization kernel for it. Here's how.
#
# Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
# op. First, when FakeTensor sees this op:
# - If the schema says it returns nothing, we can generate a trivial
#   FakeTensor rule for it (that returns nothing).
# - Otherwise, the user needs to provide a FakeTensor impl (fake impl).
#
# Next, when Python FunctionalTensor sees the op, it will functionalize
# it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...})
# HOP and replacing the mutated inputs with the corresponding outputs of this HOP.
# This HOP effectively runs the functional version of the op when
# called: it clones inputs that will be mutated, runs the op, and
# then returns (output, Tensors with the new values). A sketch of this
# rewrite is shown below.
#
# auto_functionalize_v2 is an improved version of auto_functionalize that
# better handles re-inplacing of views.

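# Illustrative sketch (hypothetical op and graph, for documentation only):
# given a custom op `mylib::sin_(Tensor(a!) x) -> ()`, functionalization
# conceptually rewrites
#
#   torch.ops.mylib.sin_(x)
#
# into something along the lines of
#
#   x_updated = auto_functionalized(torch.ops.mylib.sin_.default, x=x)[-1]
#   # ... every later use of `x` is replaced with `x_updated`, and the new
#   # value is eventually copied back into the original input.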

class AutoFunctionalized(HigherOrderOperator):
    """auto_functionalized(_mutable_op, **kwargs)

    This HOP runs a "functional" version of _mutable_op.

    Concretely, it looks at all the arguments that are mutable through
    _mutable_op's operator schema, clones those kwargs, runs
    `out = _mutable_op(**kwargs)` with the cloned values, and then returns the
    operator output concatenated with the cloned values that were mutated.

    We have some restrictions on `_mutable_op`.
    See `can_auto_functionalize` for the restrictions. We can likely lift
    many of these if users request it.

    The reason why _mutable_op is prefixed with an
    underscore is to prevent collisions with kwarg names in **kwargs.
    """

    def __init__(self) -> None:
        super().__init__("auto_functionalized")

    def __call__(
        self,
        /,
        _mutable_op: OpOverload,
        **kwargs: Any,
    ) -> Tuple[Any, Tuple[Tensor, ...]]:
        assert can_auto_functionalize(_mutable_op)
        assert isinstance(kwargs, dict)
        return super().__call__(_mutable_op, **kwargs)


auto_functionalized = AutoFunctionalized()
auto_functionalized.__module__ = "torch.ops.higher_order"

auto_functionalized.fallthrough(DispatchKey.AutogradCPU)
auto_functionalized.fallthrough(DispatchKey.AutogradCUDA)

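# Illustrative sketch (hypothetical custom op, shown as comments only): the
# return-value contract described in the docstring above. For
# `mylib::sin_(Tensor(a!) x) -> ()`:
#
#   out, new_x = auto_functionalized(torch.ops.mylib.sin_.default, x=x)
#   # `out` is None because sin_ returns nothing; `new_x` is a clone of `x`
#   # holding the mutated values, and callers replace later uses of `x` with it.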

class AutoFunctionalizedV2(HigherOrderOperator):
    """auto_functionalized_v2(_mutable_op, **kwargs)

    This HOP runs a "functional" version of _mutable_op.
    Unlike AutoFunctionalized, this version is improved to better handle
    view tensors. This version is only used in non-export mode.
    """

    def __init__(self) -> None:
        super().__init__("auto_functionalized_v2")

    def __call__(
        self,
        /,
        _mutable_op: OpOverload,
        **kwargs: Any,
    ) -> Tuple[Any, Tuple[Tensor, ...]]:
        assert can_auto_functionalize(_mutable_op)
        assert isinstance(kwargs, dict)
        return super().__call__(_mutable_op, **kwargs)


auto_functionalized_v2 = AutoFunctionalizedV2()
auto_functionalized_v2.__module__ = "torch.ops.higher_order"

auto_functionalized_v2.fallthrough(DispatchKey.AutogradCPU)
auto_functionalized_v2.fallthrough(DispatchKey.AutogradCUDA)


def can_auto_functionalize(op: OperatorBase) -> bool:
    if not isinstance(op, OpOverload):
        return False

    if torch._library.utils.is_builtin(op):
        # We control the built-ins. These may (in rare cases)
        # do input metadata mutation (which we have banned on custom ops).
        return False
    schema = op._schema
    if not schema.is_mutable:
        return False

    for arg in schema.arguments:
        if arg.alias_info is None:
            continue
        if not arg.alias_info.is_write:
            continue
        if type(arg.type) is torch.TensorType:
            continue
        if (
            type(arg.type) is torch.OptionalType
            and type(arg.type.getElementType()) is torch.TensorType
        ):
            continue
        if (
            type(arg.type) is torch.ListType
            and type(arg.type.getElementType()) is torch.TensorType
        ):
            continue
        # Not yet supported: other Tensor types. This includes things like
        # Tensor?[], Tensor[]?.
        return False

    if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType):
        # Skip checking the returns when the schema returns None.
        return True
    # The returns must not alias anything.
    for ret in schema.returns:
        if ret.alias_info is None and type(ret.type) is torch.TensorType:
            continue
        # Not yet supported: List[Tensor] return.
        return False
    if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Functionalize"):
        return False
    return True

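# Illustrative sketch (hypothetical schemas, for documentation only) of what
# the checks above accept and reject:
#
#   mylib::sin_(Tensor(a!) x) -> ()                   # accepted: Tensor(a!) arg, no returns
#   mylib::foo(Tensor(a!)? x, Tensor y) -> Tensor     # accepted: optional mutable arg, non-aliasing Tensor return
#   mylib::bar(Tensor(a!)[]? xs) -> ()                # rejected: Tensor[]? mutable args are not yet supported
#   aten::add_(Tensor(a!) self, Tensor other) -> Tensor(a!)  # rejected: built-in, and the return aliases an input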

def get_mutable_args(op: OpOverload) -> Tuple[List[str], List[torch.Type]]:
    """
    Returns the list of argument names that get mutated according to the
    schema, along with their types.
    """
    mutable_args_names = [
        arg.name
        for arg in op._schema.arguments
        if arg.alias_info is not None and arg.alias_info.is_write
    ]

    mutable_args_types = [
        arg.type
        for arg in op._schema.arguments
        if arg.alias_info is not None and arg.alias_info.is_write
    ]
    return mutable_args_names, mutable_args_types

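# Illustrative sketch (hypothetical schema, for documentation only): for
#   mylib::foo(Tensor(a!) x, Tensor y, Tensor(b!)[] zs) -> ()
# `get_mutable_args(op)` would return
#   (["x", "zs"], [<TensorType>, <ListType of TensorType>])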

def do_auto_functionalize(
    op: OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """Functionalizes a call to op(*args, **kwargs) by emitting a call to
    `outs = auto_functionalized(op, normalized_kwargs)`
    and replacing the mutated (args, kwargs) with the corresponding outputs.

    The normalized_kwargs are just the (args, kwargs), but all in kwarg form.
    This makes handling easier for the auto_functionalized HOP.
    """
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # All of the (args, kwargs), but all as kwargs. The names for the
    # args come from the schema. This makes it easier for us to work with them.
    normalized_kwargs = {}
    schema = op._schema
    for idx, arg in enumerate(schema.arguments):
        # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
        if arg.name in kwargs:
            normalized_kwargs[arg.name] = kwargs[arg.name]
        elif idx < len(args):
            # If the index is out of bounds we don't need to do anything,
            # as it means the optional arg was passed with its default
            # value.
            normalized_kwargs[arg.name] = args[idx]
        else:
            normalized_kwargs[arg.name] = arg.default_value

    unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs)  # type: ignore[arg-type]
    if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
        warnings.warn(
            "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
            "Please consider using a different name for this argument to avoid potential issues."
        )
    with ctx.redispatch_to_next():
        unwrapped_outs = auto_functionalized(
            op, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    # List of the names of args that get mutated (according to the schema).
    mutable_args_names, _ = get_mutable_args(op)

    unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[
        : -len(mutable_args_names)
    ]
    unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :]

    if len(op._schema.returns) == 0:
        assert unwrapped_actual_out[0] is None
        unwrapped_actual_out = None
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_actual_out) == 1
        unwrapped_actual_out = unwrapped_actual_out[0]
    else:
        assert len(unwrapped_actual_out) == len(op._schema.returns)

    for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out):
        # Can be None if input was `Tensor(a!)?`
        if unwrapped_out is None:
            continue

        # We only handle Tensor or List[Tensor] here for now.
        def sync_update(o, orig_arg):
            ctx.replace(orig_arg, o)
            ctx.commit_update(orig_arg)
            ctx.sync(orig_arg)

        orig_arg = normalized_kwargs[name]

        if isinstance(unwrapped_out, torch.Tensor):
            sync_update(unwrapped_out, orig_arg)
        elif isinstance(unwrapped_out, list) and all(
            isinstance(o, torch.Tensor) for o in unwrapped_out
        ):
            assert len(orig_arg) == len(unwrapped_out)
            for orig_a, o in zip(orig_arg, unwrapped_out):
                sync_update(o, orig_a)
        else:
            raise RuntimeError(
                f"unsupported type for auto-functionalization: {unwrapped_out}"
            )

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]

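# Illustrative sketch (hypothetical op, for documentation only) of the
# normalization and output splitting performed above. For
# `mylib::scale_(Tensor(a!) x, float factor) -> ()` called as
# `torch.ops.mylib.scale_(t, factor=2.0)`:
#
#   normalized_kwargs == {"x": t, "factor": 2.0}
#   unwrapped_outs    == (None, new_x)   # op output(s) first, then mutated args
#   # `t` is then updated to `new_x` via ctx.replace/commit_update/sync, and
#   # the op output (here None) is wrapped and returned.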

def do_auto_functionalize_v2(
    op: OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # All of the (args, kwargs), but all as kwargs. The names for the
    # args come from the schema. This makes it easier for us to work with them.
    normalized_kwargs = {}

    schema = op._schema
    for idx, arg in enumerate(schema.arguments):
        # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
        if arg.name in kwargs:
            normalized_kwargs[arg.name] = kwargs[arg.name]
        elif idx < len(args):
            # If the index is out of bounds we don't need to do anything,
            # as it means the optional arg was passed with its default
            # value.
            normalized_kwargs[arg.name] = args[idx]
        else:
            normalized_kwargs[arg.name] = arg.default_value

    # List of the names of args that get mutated (according to the schema).
    mutable_args_names, mutable_args_types = get_mutable_args(op)

    # A list of all bases of mutable args without duplication.
    all_bases = []
    all_bases_addresses: List[int] = []

    # Map arg_name to the index of its base in all_bases.
    arg_to_base_index: Dict[str, Any] = {}

    def update_dict(tensor, arg_name, index=None):
        base = tensor if get_base(tensor) is None else get_base(tensor)

        def set_result(base_index):
            if index is None:
                arg_to_base_index[arg_name] = base_index
            else:
                arg_to_base_index[arg_name][index] = base_index

        if base._cdata not in all_bases_addresses:
            all_bases_addresses.append(base._cdata)
            all_bases.append(base)
            set_result(len(all_bases) - 1)
        else:
            set_result(all_bases_addresses.index(base._cdata))

    for arg_name in mutable_args_names:
        arg = normalized_kwargs[arg_name]
        if arg is None:
            continue

        if isinstance(arg, list):
            arg_to_base_index[arg_name] = {}
            for i, tensor in enumerate(arg):
                if tensor is None:
                    # A None element has no base; record None so lookups by index still work.
                    arg_to_base_index[arg_name][i] = None
                    continue

                update_dict(tensor, arg_name, i)

        else:
            update_dict(arg, arg_name)

    # Add view metadata for each mutable arg into normalized_kwargs.
    write_view_information_to_args(
        mutable_args_names,
        mutable_args_types,
        normalized_kwargs,
        arg_to_base_index,
    )

    # Remove the mutated args from the kwargs (each is a function of _all_bases now).
    for arg_name in mutable_args_names:
        del normalized_kwargs[arg_name]  # type: ignore[arg-type]

    unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs)  # type: ignore[arg-type]
    if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
        warnings.warn(
            "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
            "Please consider using a different name for this argument to avoid potential issues."
        )
    all_bases_unwrapped = ctx.unwrap_tensors(all_bases)

    with ctx.redispatch_to_next():
        unwrapped_outs = auto_functionalized_v2(
            op, **dict(unwrapped_kwargs, _all_bases=all_bases_unwrapped)  # type: ignore[arg-type]
        )

    unwrapped_actual_out: Union[Any, Tuple[Any]] = (
        unwrapped_outs if len(all_bases) == 0 else unwrapped_outs[: -len(all_bases)]
    )

    unwrapped_mutable_out = (
        [] if len(all_bases) == 0 else unwrapped_outs[-len(all_bases) :]
    )

    if len(op._schema.returns) == 0:
        assert unwrapped_actual_out[0] is None
        unwrapped_actual_out = None
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_actual_out) == 1
        unwrapped_actual_out = unwrapped_actual_out[0]
    else:
        assert len(unwrapped_actual_out) == len(op._schema.returns)

    for orig_arg, unwrapped_out in zip(all_bases, unwrapped_mutable_out):
        # Can be None if input was `Tensor(a!)?`
        if unwrapped_out is None:
            continue

        # We only handle Tensor or List[Tensor] here for now.
        def sync_update(o, orig_arg):
            ctx.replace(orig_arg, o)
            ctx.commit_update(orig_arg)
            ctx.sync(orig_arg)

        if isinstance(unwrapped_out, torch.Tensor):
            sync_update(unwrapped_out, orig_arg)
        elif isinstance(unwrapped_out, list) and all(
            isinstance(o, torch.Tensor) for o in unwrapped_out
        ):
            assert len(orig_arg) == len(unwrapped_out)
            for orig_a, o in zip(orig_arg, unwrapped_out):
                sync_update(o, orig_a)
        else:
            raise RuntimeError(
                f"unsupported type for auto-functionalization: {unwrapped_out}"
            )

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]

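# Illustrative sketch (hypothetical op and tensors, for documentation only) of
# the base/view bookkeeping above. If a custom op mutates both `a` and the
# view `b = a[0]`, passed as args "x" and "y":
#
#   all_bases         == [a]            # deduplicated by storage (_cdata)
#   arg_to_base_index == {"x": 0, "y": 0}
#   # normalized_kwargs loses "x"/"y" and gains "_x_base_index", "_y_base_index",
#   # "_y_size", "_y_stride", "_y_storage_offset"; the HOP receives _all_bases=[a]
#   # and returns the op outputs followed by the new value of each base.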

# auto_functionalize functions
@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_dense(
    _mutable_op: OpOverload,
    _only_clone_these_tensors: Optional[Tuple[str, ...]] = None,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    new_kwargs = dict(**kwargs)
    result = []

    _mutable_args_names, _ = get_mutable_args(_mutable_op)
    for name in _mutable_args_names:
        if (
            _only_clone_these_tensors is not None
            and name not in _only_clone_these_tensors
        ):
            new_kwargs[name] = kwargs[name]
        else:
            arg = kwargs[name]
            if arg is None:
                new_kwargs[name] = None
            elif isinstance(arg, list):
                new_kwargs[name] = [clone_preserve_strides(x) for x in arg]
            else:
                new_kwargs[name] = clone_preserve_strides(arg)
        result.append(new_kwargs[name])
    out = _mutable_op(**new_kwargs)

    if isinstance(out, tuple):
        return (*out, *result)  # type: ignore[return-value]
    else:
        return (out, *result)  # type: ignore[return-value]

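# Illustrative sketch (hypothetical op, shown as comments only): what the dense
# kernel above computes for `mylib::sin_(Tensor(a!) x) -> ()`:
#
#   out = auto_functionalized_dense(torch.ops.mylib.sin_.default, x=x)
#   # equivalent to:
#   #   x_clone = clone_preserve_strides(x)
#   #   torch.ops.mylib.sin_(x_clone)   # mutates only the clone
#   #   out == (None, x_clone)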

@auto_functionalized.py_impl(FakeTensorMode)
def auto_functionalized_fake(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with mode:
        result = auto_functionalized_dense(_mutable_op, **kwargs)
        return result


@auto_functionalized.py_impl(ProxyTorchDispatchMode)
def auto_functionalized_proxy(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with disable_proxy_modes_tracing():
        out = auto_functionalized(_mutable_op, **kwargs)

    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
    out_proxy = mode.tracer.create_proxy(
        "call_function",
        auto_functionalized,
        (_mutable_op,),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


@auto_functionalized.py_functionalize_impl
def auto_functionalized_func(ctx, _mutable_op, **kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        result = auto_functionalized(_mutable_op, **unwrapped_kwargs)
    return ctx.wrap_tensors(result)


# auto_functionalized_v2 functions
@auto_functionalized_v2.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_v2_dense(
    _mutable_op: OpOverload,
    _only_clone_these_bases: Optional[Tuple[int, ...]] = None,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    all_bases: List[Tensor] = kwargs.pop("_all_bases", [])
    mutable_args_names, mutable_args_types = get_mutable_args(_mutable_op)
    args_view_info = read_view_information_from_args(
        mutable_args_names, mutable_args_types, kwargs, all_bases
    )

    if _only_clone_these_bases is None:
        _only_clone_these_bases = tuple(range(len(all_bases)))

    def maybe_copy(i, t):
        if t is None:
            return None
        if i in _only_clone_these_bases:
            return clone_preserve_strides(t)
        else:
            return t

    all_bases_new = [maybe_copy(i, t) for i, t in enumerate(all_bases)]

    # Create the new args.
    new_kwargs = dict(**kwargs)

    # Regenerate all inputs from all_bases_new using args_view_info and add them to new_kwargs.
    for arg_name in mutable_args_names:
        if args_view_info[arg_name] is None:
            new_kwargs[arg_name] = None
        elif isinstance(args_view_info[arg_name], list):
            new_kwargs[arg_name] = []
            for i, elem in enumerate(args_view_info[arg_name]):
                if elem is None:
                    new_kwargs[arg_name].append(None)
                else:
                    view_info = args_view_info[arg_name][i]
                    new_kwargs[arg_name].append(
                        view_info.regenerate_view(all_bases_new)
                    )
        else:
            new_kwargs[arg_name] = args_view_info[arg_name].regenerate_view(
                all_bases_new
            )

    out = _mutable_op(**new_kwargs)

    if isinstance(out, tuple):
        return (*out, *all_bases_new)  # type: ignore[return-value]
    else:
        return (out, *all_bases_new)  # type: ignore[return-value]

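# Illustrative sketch (hypothetical op and views, for documentation only) of
# the v2 dense kernel above. With a single base `a` and a mutable arg "x" that
# is the view `a[0]`, the call
#
#   auto_functionalized_v2_dense(op, _x_base_index=0, _x_size=..., _x_stride=...,
#                                _x_storage_offset=..., _all_bases=[a])
#
# clones `a`, regenerates the view `a_clone[0]` via as_strided, runs the op on
# that regenerated view, and returns the op outputs followed by `a_clone`.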


@auto_functionalized_v2.py_impl(FakeTensorMode)
def auto_functionalized_v2_fake(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with mode:
        result = auto_functionalized_v2_dense(_mutable_op, **kwargs)
        return result


@auto_functionalized_v2.py_impl(ProxyTorchDispatchMode)
def auto_functionalized_v2_proxy(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with disable_proxy_modes_tracing():
        out = auto_functionalized_v2(_mutable_op, **kwargs)

    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
    out_proxy = mode.tracer.create_proxy(
        "call_function",
        auto_functionalized_v2,
        (_mutable_op,),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


@auto_functionalized_v2.py_functionalize_impl
def auto_functionalized_v2_func(ctx, _mutable_op, **kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        result = auto_functionalized_v2(_mutable_op, **unwrapped_kwargs)
    return ctx.wrap_tensors(result)