# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch import Tensor
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator, OperatorBase, OpOverload
from torch._prims_common import clone_preserve_strides
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


def get_base(tensor):
    if torch.is_inference_mode_enabled():
        return tensor._inference_mode_base
    else:
        return tensor._base


@dataclass
class ViewInfo:
    base_index: int
    size: Optional[Sequence[Union[int, torch.SymInt]]] = None
    stride: Optional[Sequence[Union[int, torch.SymInt]]] = None
    storage_offset: Optional[int] = None
    # When is_view is false, the tensor is the base, and
    # size, stride and storage_offset are all None.
    is_view: bool = True

    def regenerate_view(self, bases_list: List[Tensor]):
        if not self.is_view:
            return bases_list[self.base_index]

        assert self.stride is not None
        assert self.size is not None
        assert self.storage_offset is not None

        return torch.as_strided(
            bases_list[self.base_index],
            self.size,
            self.stride,
            self.storage_offset,
        )
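

# Illustrative example (not executed here): for a base tensor `b` and a view
# `v = b[1:]`, the view can be recorded as
#     ViewInfo(0, v.size(), v.stride(), v.storage_offset(), is_view=True)
# and `regenerate_view([b])` rebuilds an equivalent view of `b` via
# torch.as_strided, while `ViewInfo(0, is_view=False)` simply returns the base.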


def write_view_information_to_args(
    mutable_arg_names: List[str],
    mutable_arg_types: List[torch.Type],
    kwargs: Dict[str, Any],
    arg_to_base_index: Dict[str, Any],
):
    """
    This function writes view information into kwargs. It reads the mutable args from
    kwargs and uses arg_to_base_index and the tensors' metadata to write ViewInfo
    fields into kwargs.
    mutable_arg_names: mutable custom operator arg names.
    mutable_arg_types: mutable custom operator arg types.
    kwargs: the original custom operator args.
    arg_to_base_index: maps each mutable_arg_name to an int | [int] that refers to the
        base tensor that corresponds to the input tensor.
    """

    def write_single_view(prefix: str, tensor: Tensor, base_index: int):
        assert f"{prefix}_base_index" not in kwargs
        assert f"{prefix}_size" not in kwargs
        assert f"{prefix}_stride" not in kwargs
        assert f"{prefix}_storage_offset" not in kwargs

        if tensor is None:
            kwargs[f"{prefix}_base_index"] = None
        elif get_base(tensor) is None:
            # If the tensor is the base (not a view), for simplicity we do not
            # serialize its view metadata.
            kwargs[f"{prefix}_base_index"] = base_index
        else:
            kwargs[f"{prefix}_base_index"] = base_index
            kwargs[f"{prefix}_size"] = tensor.size()
            kwargs[f"{prefix}_stride"] = tensor.stride()
            kwargs[f"{prefix}_storage_offset"] = tensor.storage_offset()

    for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
        arg = kwargs[arg_name]
        if isinstance(arg_type, torch.ListType):
            if arg is None:
                # Nothing more to record for a None list.
                kwargs[f"_{arg_name}_length"] = None
                continue

            kwargs[f"_{arg_name}_length"] = len(arg)
            for i, elem in enumerate(arg):
                write_single_view(
                    f"_{arg_name}_{i}", elem, arg_to_base_index[arg_name][i]
                )

        elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)):
            write_single_view(
                f"_{arg_name}",
                kwargs[arg_name],
                arg_to_base_index.get(arg_name, None),
            )
        else:
            raise RuntimeError(f"Unsupported type {arg_type}")


# Returns a dict of arg_name -> ViewInfo | [ViewInfo]
def read_view_information_from_args(
    mutable_arg_names: List[str],
    mutable_arg_types: List[torch.Type],
    kwargs: Dict[str, Any],
    all_bases: List[Tensor],
):
    """
    This reads the view information added by `write_view_information_to_args` from
    kwargs, pops it, and returns a dict arg_name -> ViewInfo | [ViewInfo] (if the input
    is a list) that maps each mutable arg to its view information.
    mutable_arg_names: mutable custom operator arg names.
    mutable_arg_types: mutable custom operator arg types.
    kwargs: args of auto_functionalize(custom_op, kwargs)
    """

    def get_arg(name):
        return kwargs.pop(name)

    def read_single_view(prefix):
        base_index = get_arg(f"{prefix}_base_index")
        if base_index is None:
            return None
        elif f"{prefix}_size" not in kwargs:
            assert f"{prefix}_stride" not in kwargs
            assert f"{prefix}_storage_offset" not in kwargs

            # This means that the argument is the base tensor.
            return ViewInfo(base_index, all_bases[base_index], is_view=False)

        else:
            size = get_arg(f"{prefix}_size")
            stride = get_arg(f"{prefix}_stride")
            storage_offset = get_arg(f"{prefix}_storage_offset")
            return ViewInfo(base_index, size, stride, storage_offset, is_view=True)

    args_view_info: Dict[str, Any] = {}
    for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types):
        if isinstance(arg_type, torch.ListType):
            length = get_arg(f"_{arg_name}_length")
            if length is None:
                # The whole list is None.
                args_view_info[arg_name] = None
            else:
                args_view_info[arg_name] = [
                    read_single_view(f"_{arg_name}_{i}") for i in range(length)
                ]

        elif isinstance(arg_type, (torch.TensorType, torch.OptionalType)):
            args_view_info[arg_name] = read_single_view(f"_{arg_name}")
        else:
            raise RuntimeError(f"Unsupported type {arg_type}")
    return args_view_info
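

# Illustrative sketch (hypothetical arg name "x"): for a single mutable tensor arg
# that is a view, `write_view_information_to_args(["x"], [arg_type], kwargs, {"x": 0})`
# (where `arg_type` is the schema's TensorType for "x") adds the keys
#     "_x_base_index", "_x_size", "_x_stride", "_x_storage_offset"
# to kwargs, and `read_view_information_from_args` later pops them back out as
#     {"x": ViewInfo(0, size, stride, storage_offset, is_view=True)}.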


# NOTE: [auto-functionalizing custom ops]
# Users may wish to torch.compile custom ops that mutate their inputs.
# torch.compile will automatically support this op without anyone needing
# to provide a functionalization kernel for it. Here's how.
#
# Let's say we have a hypothetical mylib::sin_(Tensor(a!) x) -> ()
# op. First, when FakeTensor sees this op:
# - If the schema says it returns nothing, we can generate a trivial
#   FakeTensor rule for it (that returns nothing).
# - Otherwise, the user needs to provide a FakeTensor impl (fake impl).
#
# Next, when Python FunctionalTensor sees the op, it will functionalize
# it by emitting a call to an auto_functionalize(op, ["x"], {"x": ...})
# HOP and replacing the mutated inputs with the corresponding outputs of this HOP.
# This HOP effectively runs the functional version of the op when
# called: it clones inputs that will be mutated, runs the op, and
# then returns (output, Tensors with the new values).
#
# auto_functionalize_v2 is an improved version of auto_functionalize that better
# handles re-inplacing views.
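

# Illustrative sketch of the rewrite (variable names are made up for this example):
#
#     x = torch.randn(3)
#     torch.ops.mylib.sin_(x)   # mutates x in place
#
# is functionalized roughly as
#
#     outs = auto_functionalized(torch.ops.mylib.sin_.default, x=x)
#     new_x = outs[-1]          # the cloned tensor holding the mutated value of x
#
# and subsequent uses of `x` are redirected to `new_x`.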


class AutoFunctionalized(HigherOrderOperator):
    """auto_functionalized(_mutable_op, **kwargs)

    This HOP runs a "functional" version of _mutable_op.

    Concretely, it looks at all the arguments that are mutable through
    _mutable_op's operator schema, clones those kwargs, runs
    `out = _mutable_op(**kwargs)` with the cloned values, and then returns the
    operator output concatenated with the cloned values that were mutated.

    We have some restrictions on `_mutable_op`.
    See `can_auto_functionalize` for the restrictions. We can likely lift
    many of these if users request it.

    The reason why _mutable_op is prefixed with an
    underscore is to prevent collisions with kwarg names in **kwargs.
    """

    def __init__(self) -> None:
        super().__init__("auto_functionalized")

    def __call__(
        self,
        /,
        _mutable_op: OpOverload,
        **kwargs: Any,
    ) -> Tuple[Any, Tuple[Tensor, ...]]:
        assert can_auto_functionalize(_mutable_op)
        assert isinstance(kwargs, dict)
        return super().__call__(_mutable_op, **kwargs)


auto_functionalized = AutoFunctionalized()
auto_functionalized.__module__ = "torch.ops.higher_order"

auto_functionalized.fallthrough(DispatchKey.AutogradCPU)
auto_functionalized.fallthrough(DispatchKey.AutogradCUDA)


class AutoFunctionalizedV2(HigherOrderOperator):
    """auto_functionalized_v2(_mutable_op, **kwargs)

    This HOP runs a "functional" version of _mutable_op.
    Unlike AutoFunctionalized, this version is improved to better handle
    view tensors. It is only used in non-export mode.
    """

    def __init__(self) -> None:
        super().__init__("auto_functionalized_v2")

    def __call__(
        self,
        /,
        _mutable_op: OpOverload,
        **kwargs: Any,
    ) -> Tuple[Any, Tuple[Tensor, ...]]:
        assert can_auto_functionalize(_mutable_op)
        assert isinstance(kwargs, dict)
        return super().__call__(_mutable_op, **kwargs)


auto_functionalized_v2 = AutoFunctionalizedV2()
auto_functionalized_v2.__module__ = "torch.ops.higher_order"

auto_functionalized_v2.fallthrough(DispatchKey.AutogradCPU)
auto_functionalized_v2.fallthrough(DispatchKey.AutogradCUDA)


def can_auto_functionalize(op: OperatorBase) -> bool:
    if not isinstance(op, OpOverload):
        return False

    if torch._library.utils.is_builtin(op):
        # We control the built-ins. These may (in rare cases) do input metadata
        # mutation (which we have banned on custom ops).
        return False
    schema = op._schema
    if not schema.is_mutable:
        return False

    for arg in schema.arguments:
        if arg.alias_info is None:
            continue
        if not arg.alias_info.is_write:
            continue
        if type(arg.type) is torch.TensorType:
            continue
        if (
            type(arg.type) is torch.OptionalType
            and type(arg.type.getElementType()) is torch.TensorType
        ):
            continue
        if (
            type(arg.type) is torch.ListType
            and type(arg.type.getElementType()) is torch.TensorType
        ):
            continue
        # Not yet supported: other Tensor types. This includes things like
        # Tensor?[], Tensor[]?.
        return False

    if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType):
        # Skip schema returns -> None
        return True
    # The returns must not alias anything
    for ret in schema.returns:
        if ret.alias_info is None and type(ret.type) is torch.TensorType:
            continue
        # Not yet supported: List[Tensor] return.
        return False
    if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Functionalize"):
        return False
    return True


def get_mutable_args(op: OpOverload) -> Tuple[List[str], List[torch.Type]]:
    """
    Returns the list of argument names that get mutated according to the
    schema, along with their types.
    """
    mutable_args_names = [
        arg.name
        for arg in op._schema.arguments
        if arg.alias_info is not None and arg.alias_info.is_write
    ]

    mutable_args_types = [
        arg.type
        for arg in op._schema.arguments
        if arg.alias_info is not None and arg.alias_info.is_write
    ]
    return mutable_args_names, mutable_args_types
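

# For example (hypothetical schema): for "mylib::sin_(Tensor(a!) x) -> ()",
# get_mutable_args returns (["x"], [<TensorType>]), since `x` is the only
# argument whose alias annotation is marked as a write.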


def do_auto_functionalize(
    op: OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """Functionalizes a call to op(*args, **kwargs) by emitting a call to
    `outs = auto_functionalized(op, normalized_kwargs)`
    and replacing the mutated (args, kwargs) with the corresponding outputs.

    The normalized_kwargs are just the (args, kwargs), but all in kwarg form.
    This makes handling easier for the auto_functionalized HOP.
    """
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # All of the (args, kwargs), but all as kwargs. The names for the
    # args come from the schema. This makes it easier for us to work with them.
    normalized_kwargs = {}
    schema = op._schema
    for idx, arg in enumerate(schema.arguments):
        # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
        if arg.name in kwargs:
            normalized_kwargs[arg.name] = kwargs[arg.name]
        elif idx < len(args):
            # If idx is out of bounds, we don't need to do anything special, as it
            # means the optional arg was passed with its default value.
            normalized_kwargs[arg.name] = args[idx]
        else:
            normalized_kwargs[arg.name] = arg.default_value

    unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs)  # type: ignore[arg-type]
    if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
        warnings.warn(
            "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
            "Please consider using a different name for this argument to avoid potential issues."
        )
    with ctx.redispatch_to_next():
        unwrapped_outs = auto_functionalized(
            op, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    # List of the names of args that get mutated (according to the schema)
    mutable_args_names, _ = get_mutable_args(op)

    unwrapped_actual_out: Union[Any, Tuple[Any]] = unwrapped_outs[
        : -len(mutable_args_names)
    ]
    unwrapped_mutable_out = unwrapped_outs[-len(mutable_args_names) :]

    if len(op._schema.returns) == 0:
        assert unwrapped_actual_out[0] is None
        unwrapped_actual_out = None
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_actual_out) == 1
        unwrapped_actual_out = unwrapped_actual_out[0]
    else:
        assert len(unwrapped_actual_out) == len(op._schema.returns)

    for name, unwrapped_out in zip(mutable_args_names, unwrapped_mutable_out):
        # Can be None if the input was `Tensor(a!)?`
        if unwrapped_out is None:
            continue

        # We only handle Tensor or List[Tensor] here for now.
        def sync_update(o, orig_arg):
            ctx.replace(orig_arg, o)
            ctx.commit_update(orig_arg)
            ctx.sync(orig_arg)

        orig_arg = normalized_kwargs[name]

        if isinstance(unwrapped_out, torch.Tensor):
            sync_update(unwrapped_out, orig_arg)
        elif isinstance(unwrapped_out, list) and all(
            isinstance(o, torch.Tensor) for o in unwrapped_out
        ):
            assert len(orig_arg) == len(unwrapped_out)
            for orig_a, o in zip(orig_arg, unwrapped_out):
                sync_update(o, orig_a)
        else:
            raise RuntimeError(
                f"unsupported type for auto-functionalization: {unwrapped_out}"
            )

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]
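

# Illustrative sketch (hypothetical op "mylib::foo(Tensor(a!) x, Tensor y) -> Tensor"):
# do_auto_functionalize normalizes the call to kwargs {"x": x, "y": y}, emits
#     unwrapped_outs = auto_functionalized(op, x=x, y=y)
# and then splits unwrapped_outs into the op's real output(s) (everything except the
# trailing entries) and the new values of the mutated args (the trailing entries,
# here just the updated `x`), which are synced back into the original inputs.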


def do_auto_functionalize_v2(
    op: OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    # All of the (args, kwargs), but all as kwargs. The names for the
    # args come from the schema. This makes it easier for us to work with them.
    normalized_kwargs = {}

    schema = op._schema
    for idx, arg in enumerate(schema.arguments):
        # NB: torch_dispatch kwargs are the args defined as kwarg-only in the schema
        if arg.name in kwargs:
            normalized_kwargs[arg.name] = kwargs[arg.name]
        elif idx < len(args):
            # If idx is out of bounds, we don't need to do anything special, as it
            # means the optional arg was passed with its default value.
            normalized_kwargs[arg.name] = args[idx]
        else:
            normalized_kwargs[arg.name] = arg.default_value

    # List of the names of args that get mutated (according to the schema)
    mutable_args_names, mutable_args_types = get_mutable_args(op)

    # A list of all bases of mutable args without duplication
    all_bases = []
    all_bases_addresses: List[int] = []

    # Map arg_name to the index of its base in all_bases.
    arg_to_base_index: Dict[str, Any] = {}

    def update_dict(tensor, arg_name, index=None):
        base = tensor if get_base(tensor) is None else get_base(tensor)

        def set_result(base_index):
            if index is None:
                arg_to_base_index[arg_name] = base_index
            else:
                arg_to_base_index[arg_name][index] = base_index

        if base._cdata not in all_bases_addresses:
            all_bases_addresses.append(base._cdata)
            all_bases.append(base)
            set_result(len(all_bases) - 1)
        else:
            set_result(all_bases_addresses.index(base._cdata))

    for arg_name in mutable_args_names:
        arg = normalized_kwargs[arg_name]
        if arg is None:
            continue

        if isinstance(arg, list):
            arg_to_base_index[arg_name] = {}
            for i, tensor in enumerate(arg):
                if tensor is None:
                    # Record that this list element has no base tensor.
                    arg_to_base_index[arg_name][i] = None
                    continue

                update_dict(tensor, arg_name, i)

        else:
            update_dict(arg, arg_name)

    # Add the view metadata for each arg into normalized_kwargs.
    write_view_information_to_args(
        mutable_args_names,
        mutable_args_types,
        normalized_kwargs,
        arg_to_base_index,
    )

    # Remove mutated args from the kwargs (they are a function of _all_bases now).
    for arg_name in mutable_args_names:
        del normalized_kwargs[arg_name]  # type: ignore[arg-type]

    unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs)  # type: ignore[arg-type]
    if "self" in unwrapped_kwargs or "self_" in unwrapped_kwargs:
        warnings.warn(
            "Using `self` or `self_` as an argument in the definition of custom ops may lead to ambiguous parsing. "
            "Please consider using a different name for this argument to avoid potential issues."
        )
    all_bases_unwrapped = ctx.unwrap_tensors(all_bases)

    with ctx.redispatch_to_next():
        unwrapped_outs = auto_functionalized_v2(
            op, **dict(unwrapped_kwargs, _all_bases=all_bases_unwrapped)  # type: ignore[arg-type]
        )

    unwrapped_actual_out: Union[Any, Tuple[Any]] = (
        unwrapped_outs if len(all_bases) == 0 else unwrapped_outs[: -len(all_bases)]
    )

    unwrapped_mutable_out = (
        [] if len(all_bases) == 0 else unwrapped_outs[-len(all_bases) :]
    )

    if len(op._schema.returns) == 0:
        assert unwrapped_actual_out[0] is None
        unwrapped_actual_out = None
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_actual_out) == 1
        unwrapped_actual_out = unwrapped_actual_out[0]
    else:
        assert len(unwrapped_actual_out) == len(op._schema.returns)

    for orig_arg, unwrapped_out in zip(all_bases, unwrapped_mutable_out):
        # Can be None if the input was `Tensor(a!)?`
        if unwrapped_out is None:
            continue

        # We only handle Tensor or List[Tensor] here for now.
        def sync_update(o, orig_arg):
            ctx.replace(orig_arg, o)
            ctx.commit_update(orig_arg)
            ctx.sync(orig_arg)

        if isinstance(unwrapped_out, torch.Tensor):
            sync_update(unwrapped_out, orig_arg)
        elif isinstance(unwrapped_out, list) and all(
            isinstance(o, torch.Tensor) for o in unwrapped_out
        ):
            assert len(orig_arg) == len(unwrapped_out)
            for orig_a, o in zip(orig_arg, unwrapped_out):
                sync_update(o, orig_a)
        else:
            raise RuntimeError(
                f"unsupported type for auto-functionalization: {unwrapped_out}"
            )

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]
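

# Illustrative sketch (hypothetical op "mylib::bar(Tensor(a!) x) -> ()", where `x` is
# a view of some base `b`): instead of passing `x` directly, do_auto_functionalize_v2
# emits roughly
#     auto_functionalized_v2(
#         op,
#         _x_base_index=0,
#         _x_size=x.size(), _x_stride=x.stride(), _x_storage_offset=x.storage_offset(),
#         _all_bases=[b],
#     )
# so the HOP receives the bases plus enough metadata to regenerate `x` as a view of
# the (possibly cloned) base.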


# auto_functionalize functions
@auto_functionalized.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_dense(
    _mutable_op: OpOverload,
    _only_clone_these_tensors: Optional[Tuple[str, ...]] = None,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    new_kwargs = dict(**kwargs)
    result = []

    _mutable_args_names, _ = get_mutable_args(_mutable_op)
    for name in _mutable_args_names:
        if (
            _only_clone_these_tensors is not None
            and name not in _only_clone_these_tensors
        ):
            new_kwargs[name] = kwargs[name]
        else:
            new_kwargs[name] = (
                [clone_preserve_strides(x) for x in kwargs[name]]
                if kwargs[name] is not None and isinstance(kwargs[name], list)
                else clone_preserve_strides(kwargs[name])
                if kwargs[name] is not None
                else None
            )
        result.append(new_kwargs[name])
    out = _mutable_op(**new_kwargs)

    if isinstance(out, tuple):
        return (*out, *result)  # type: ignore[return-value]
    else:
        return (out, *result)  # type: ignore[return-value]


@auto_functionalized.py_impl(FakeTensorMode)
def auto_functionalized_fake(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with mode:
        result = auto_functionalized_dense(_mutable_op, **kwargs)
        return result


@auto_functionalized.py_impl(ProxyTorchDispatchMode)
def auto_functionalized_proxy(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with disable_proxy_modes_tracing():
        out = auto_functionalized(_mutable_op, **kwargs)

    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
    out_proxy = mode.tracer.create_proxy(
        "call_function",
        auto_functionalized,
        (_mutable_op,),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


@auto_functionalized.py_functionalize_impl
def auto_functionalized_func(ctx, _mutable_op, **kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        result = auto_functionalized(_mutable_op, **unwrapped_kwargs)
    return ctx.wrap_tensors(result)
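

# For example (hypothetical op "mylib::foo(Tensor(a!) x, Tensor y) -> Tensor"):
# auto_functionalized_dense clones `x`, runs the op on the clone, and returns
#     (op_output, cloned_and_mutated_x)
# i.e. the op's outputs followed by one entry per mutable argument. If
# `_only_clone_these_tensors` is given, arguments not listed there are mutated
# in place instead of being cloned.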


# auto_functionalized_v2 functions
@auto_functionalized_v2.py_impl(DispatchKey.CompositeExplicitAutograd)
def auto_functionalized_v2_dense(
    _mutable_op: OpOverload,
    _only_clone_these_bases: Optional[Tuple[int, ...]] = None,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    all_bases: List[Tensor] = kwargs.pop("_all_bases", [])
    mutable_args_names, mutable_args_types = get_mutable_args(_mutable_op)
    args_view_info = read_view_information_from_args(
        mutable_args_names, mutable_args_types, kwargs, all_bases
    )

    if _only_clone_these_bases is None:
        _only_clone_these_bases = tuple(range(len(all_bases)))

    def maybe_copy(i, t):
        if t is None:
            return None
        if i in _only_clone_these_bases:
            return clone_preserve_strides(t)
        else:
            return t

    all_bases_new = [maybe_copy(i, t) for i, t in enumerate(all_bases)]

    # Create new args.
    new_kwargs = dict(**kwargs)

    # Re-generate all inputs from all_bases_new using args_view_info and add them
    # to new_kwargs.
    for arg_name in mutable_args_names:
        if args_view_info[arg_name] is None:
            new_kwargs[arg_name] = None
        elif isinstance(args_view_info[arg_name], list):
            new_kwargs[arg_name] = []
            for i, elem in enumerate(args_view_info[arg_name]):
                if elem is None:
                    new_kwargs[arg_name].append(None)
                else:
                    view_info = args_view_info[arg_name][i]
                    new_kwargs[arg_name].append(
                        view_info.regenerate_view(all_bases_new)
                    )
        else:
            new_kwargs[arg_name] = args_view_info[arg_name].regenerate_view(
                all_bases_new
            )

    out = _mutable_op(**new_kwargs)

    if isinstance(out, tuple):
        return (*out, *all_bases_new)  # type: ignore[return-value]
    else:
        return (out, *all_bases_new)  # type: ignore[return-value]


@auto_functionalized_v2.py_impl(FakeTensorMode)
def auto_functionalized_v2_fake(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with mode:
        result = auto_functionalized_v2_dense(_mutable_op, **kwargs)
        return result


@auto_functionalized_v2.py_impl(ProxyTorchDispatchMode)
def auto_functionalized_v2_proxy(
    mode,
    _mutable_op: OpOverload,
    **kwargs: Any,
) -> Tuple[Any, Tuple[Tensor, ...]]:
    with disable_proxy_modes_tracing():
        out = auto_functionalized_v2(_mutable_op, **kwargs)

    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
    out_proxy = mode.tracer.create_proxy(
        "call_function",
        auto_functionalized_v2,
        (_mutable_op,),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


@auto_functionalized_v2.py_functionalize_impl
def auto_functionalized_v2_func(ctx, _mutable_op, **kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        result = auto_functionalized_v2(_mutable_op, **unwrapped_kwargs)
    return ctx.wrap_tensors(result)
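

# End-to-end sketch of auto_functionalized_v2_dense semantics (hypothetical op and
# names): given _all_bases=[b] and view metadata describing `x` as a view of base 0,
# the dense impl clones `b` (unless excluded via _only_clone_these_bases), regenerates
# `x` as the same view of the clone, runs the op on it, and returns
#     (op_output, new_b)
# where `new_b` is the cloned base now holding the mutated values.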