# mypy: allow-untyped-defs
import contextlib
import warnings
import weakref
from abc import ABC, abstractmethod
from typing import Any, Callable, ContextManager, Dict, List, Optional, Tuple, Union

import torch
import torch._inductor.config as inductor_config
import torch.utils._pytree as pytree
from torch._C import _functionalization_reapply_views_tls as _reapply_views
from torch._ops import _get_dispatch_mode_pre_dispatch
from torch._subclasses.meta_utils import is_sparse_any
from torch.utils._python_dispatch import (
    _detect_infra_mode,
    _disable_infra_mode,
    return_and_correct_aliasing,
    TorchDispatchMode,
)


not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")


# NOTE Some special handling for tensor conversion during export is needed.
# Normally, when tracing through the model with tensor.to(), the maybe-aliasing
# relationship between input and output tensors will be baked into the graph.
# For example, if we get a tensor on the cpu device and call tensor.to("cpu"),
# it becomes a no-op in the graph. For whole-graph capture this is not sound,
# so we need to do something different. Instead, in export we try to preserve
# the tensor conversion by forcing a non-semantic-breaking aten::_to_copy
# operator to be traced in the graph, and subsequently banning mutations on all
# such converted tensors.
# In addition to patching the .to() method call in functionalization, we also have
# to patch similar methods like float() and cpu(): they intentionally don't fall
# back to .to(), but have the same behavior as .to() according to the PyTorch
# documentation. https://pytorch.org/docs/stable/generated/torch.Tensor.float.html
# Thus we simply force them to go through a .to() call.
def _conversion_method_template(**extra_kwargs):
    def _(self, *args, **kwargs):
        return self.to(*args, **{**kwargs, **extra_kwargs})

    return _
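

# For example, FunctionalTensor.float below is generated as
# _conversion_method_template(dtype=torch.float32), so x.float() is rewritten as
# x.to(dtype=torch.float32) and therefore picks up the export-only copy handling
# implemented in FunctionalTensor.to().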


class FunctionalTensor(torch.Tensor):
    """
    Functional tensors represent tensors that will remove mutations
    from a program. If you perform a mutable operation on a functional tensor,
    it will re-dispatch to the functional variant of that operation.

    Historically, functionalization is implemented in C++ in the dispatcher.
    This class is a lightweight python shim around the C++ functionalization logic.

    FunctionalTensor is required to be used with a corresponding
    FunctionalTensorMode active, because it relies
    on using the mode for dispatch (which can properly handle factory functions).
    """

    elem: torch.Tensor
    # Indicates to our torch_dispatch dispatching infra that
    # this is an "infra" mode with lower dispatching precedence.
    _mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL

    # Note: The reason we add these extra keys to our FunctionalTensor subclass
    # is to mirror the behavior of C++ functionalization (we can choose to change this
    # later, as long as it doesn't break anything).
    # FunctionalTensorWrapper copies **all** dispatch keys from the inner tensor
    # to the wrapper, excluding functorch and python dispatch keys.
    # Here I'm trying to re-use the keyset the functorch wrapper subclasses copy,
    # except that it doesn't include ZeroTensor, so I'm manually adding it in.
    _extra_dispatch_keys = torch._C._additional_keys_to_prop_for_wrapper_tensors.add(
        torch._C.DispatchKey.ZeroTensor
    )

    # These are all aten ops that correspond to metadata queries.
    # We want FunctionalTensor to be able to handle them directly.
    metadata_fns = [
        torch.ops.aten.is_contiguous.default,  # type: ignore[has-type]
        torch.ops.aten.is_contiguous.memory_format,  # type: ignore[has-type]
        torch.ops.aten.is_strides_like_format.default,  # type: ignore[has-type]
        torch.ops.aten.is_non_overlapping_and_dense.default,  # type: ignore[has-type]
        torch.ops.aten.size.default,  # type: ignore[has-type]
        torch.ops.aten.sym_size.default,  # type: ignore[has-type]
        torch.ops.aten.stride.default,  # type: ignore[has-type]
        torch.ops.aten.sym_stride.default,  # type: ignore[has-type]
        torch.ops.aten.storage_offset.default,  # type: ignore[has-type]
        torch.ops.aten.sym_storage_offset.default,  # type: ignore[has-type]
        torch.ops.aten.numel.default,  # type: ignore[has-type]
        torch.ops.aten.sym_numel.default,  # type: ignore[has-type]
        torch.ops.aten.dim.default,  # type: ignore[has-type]
        torch.ops.prim.device.default,  # type: ignore[has-type]
    ]

    # These are ops that claim to be functional, but actually are maybe-mutating/maybe-aliasing
    # TODO (tmanlaibaatar) make it a tag
    maybe_aliasing_or_mutating_ops = [
        torch.ops.aten.dropout.default,  # type: ignore[has-type]
        torch.ops.aten.batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.native_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten._batch_norm_impl_index.default,  # type: ignore[has-type]
        torch.ops.aten.cudnn_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.miopen_batch_norm.default,  # type: ignore[has-type]
        torch.ops.aten.atleast_1d.default,  # type: ignore[has-type]
        torch.ops.aten.atleast_2d.default,  # type: ignore[has-type]
        torch.ops.aten.atleast_3d.default,  # type: ignore[has-type]
        torch.ops.aten.cartesian_prod.default,  # type: ignore[has-type]
        torch.ops.aten.conj_physical.default,  # type: ignore[has-type]
        torch.ops.aten.alpha_dropout.default,  # type: ignore[has-type]
        torch.ops.aten.feature_dropout.default,  # type: ignore[has-type]
        torch.ops.aten.feature_alpha_dropout.default,  # type: ignore[has-type]
        torch.ops.aten.unsafe_chunk.default,  # type: ignore[has-type]
    ]

    # Used by auto_functionalize to determine the base of tensors during inference mode.
    _inference_mode_base: Optional["FunctionalTensor"] = None

    def __new__(cls, elem, mode):
        assert torch._is_functional_tensor(elem)

        # In general, we'd like our functional tensor subclass to only be in charge of functionalization,
        # and defer to the inner subclass for all other functionality.
        # Example: If our inner tensor is a ZeroTensor, we would want to defer running the ZeroTensor fallback
        # until after we redispatch to our inner ZeroTensor.
        # However, there are a few keys that we need to mirror between the inner and outer tensors:
        #   Conjugate
        #   Negative
        # Why? These keys are used to test metadata queries, like `.is_conj()` and `.is_neg()`.
        # We **need** calls to is_conj() to return the same thing on the outer and inner tensors,
        # because user code / framework code that branches like so needs to do the same thing
        # when it sees the outer FunctionalTensor:
        #     if (x.is_conj()) {
        #         return at::view_as_real(x.resolve_conj());
        #     } else {
        #         return at::view_as_real(x);
        #     }
        extra_dispatch_keys = (
            FunctionalTensor._extra_dispatch_keys & torch._C._dispatch_keys(elem)
        )

        out = torch.Tensor._make_wrapper_subclass(  # type: ignore[arg-type, attr-defined]
            # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
            # Calling the overload that has kwargs causes us to go down the first overload path,
            # which will **always** specialize sizes.
            # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
            cls,
            elem.shape,  # sizes
            elem.stride() if not is_sparse_any(elem) else None,  # strides
            (
                elem.storage_offset() if not is_sparse_any(elem) else None
            ),  # storage_offset
            None,  # memory_format
            elem.dtype,  # dtype
            elem.layout,  # layout
            elem.device,  # device
            False,  # pin_memory
            elem.requires_grad,  # requires_grad
            None,  # dispatch_sizes_strides_policy
            False,  # dispatch_device
            False,  # dispatch_layout
            extra_dispatch_keys,  # _extra_dispatch_keys
        )
        torch._C._set_throw_on_mutable_data_ptr(out)
        out.elem = elem

        if (
            torch.is_inference_mode_enabled()
            and torch._inductor.config.enable_auto_functionalized_v2
        ):
            if out.is_base_tensor():
                out._inference_mode_base = None
                # This assumes that the FunctionalTensor.elem does not change its storage after this point.
                # Otherwise this would be invalid.
                mode._storage_to_base[out.elem.untyped_storage()] = out
            else:
                out._inference_mode_base = mode._storage_to_base[
                    out.elem.untyped_storage()
                ]
                assert out._inference_mode_base is not None
        return out

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        unrecognized_types = [
            t
            for t in types
            if t not in [torch.Tensor, torch._subclasses.FakeTensor, FunctionalTensor]
        ]
        if unrecognized_types:
            not_implemented_log.debug(
                "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        if kwargs is None:
            kwargs = {}

        # FunctionalTensor needs to plumb all metadata requests to the inner tensor.
        # In theory we don't have to do this - but if we want to service metadata requests here,
        # we need to carefully make sure all metadata is accurate (including metadata mutations).
        if func in FunctionalTensor.metadata_fns:
            # All metadata accesses should be plumbed to the inner tensor, so that we don't have to worry
            # about keeping metadata in sync between the wrapper and the inner tensor.
            # This also alleviates us from having to manually handle metadata mutations on the wrapper.
            assert len(kwargs) == 0
            if func in [
                torch.ops.aten.is_strides_like_format.default,
                torch.ops.aten.is_contiguous.memory_format,
            ]:
                assert len(args) == 2 and isinstance(args[0], FunctionalTensor)
                return func(torch._from_functional_tensor(args[0].elem), args[1])
            assert len(args) == 1 and isinstance(args[0], FunctionalTensor)

            return func(torch._from_functional_tensor(args[0].elem))
        # Originally I tried to implement my subclass without giving it a torch_dispatch, but I gave up:
        # - _make_wrapper_subclass requires a __torch_dispatch__
        # - If we want to use _make_subclass(), we have a problem: the subclass will share a TensorImpl with the inner tensor,
        #   which is of type FunctionalTensorWrapper! We explicitly do not want our wrapper to be a FunctionalTensorWrapper.
        # - If we use the default tensor.__new__(), we have another problem: it returns inner_tensor.alias(),
        #   which causes every subclass created above autograd to have autograd view metadata
        #   (in addition to also being a FunctionalTensorWrapper).
        raise RuntimeError(
            "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()"
        )

    def __repr__(self):
        return f"FunctionalTensor({repr(self.elem)})"

    @staticmethod
    def to_functional(x):
        # We will do the wrapping for the user.

        assert not torch._is_functional_tensor(x)
        # The only autograd metadata we care about on the FunctionalTensor is:
        # - requires_grad (so autograd runs)
        # - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine)
        #   this is handled by FunctionalTensor.to_functional
        x_functional = torch._to_functional_tensor(x)
        # Technically the FunctionalTensorMode here is unnecessary,
        # but it avoids spurious NotImplemented logs during `ProxyTorchDispatchMode` tracing.
        # _mirror_autograd_meta_to queries tensor sizes,
        # and otherwise the sym_size() call will go to the proxy mode before hitting
        # FunctionalTensor.__torch_dispatch__

        functional_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)
        assert functional_mode is not None

        with functional_mode:
            torch._mirror_autograd_meta_to(x, x_functional)  # type: ignore[attr-defined]
            out = FunctionalTensor(x_functional, functional_mode)
            torch._mirror_autograd_meta_to(x_functional, out)  # type: ignore[attr-defined]
        return out

    def from_functional(self):
        torch._sync(self)
        return torch._from_functional_tensor(self.elem)

    def is_base_tensor(self) -> bool:
        return torch._is_functional_tensor_base(self.elem)

    def replace_(self, output) -> None:
        torch._functionalize_replace(self.elem, output)

    def commit_update(self) -> None:
        torch._functionalize_commit_update(self.elem)

    def sync(self) -> None:
        torch._functionalize_sync(self.elem)

    def mark_mutation_hidden_from_autograd(self) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(self.elem)

    def tolist(self) -> Any:
        if self.elem.dim() == 0:
            return self.elem.item()
        elif self.elem.dim() == 1:
            return [elem.item() for elem in self.elem]
        else:
            return [elem.tolist() for elem in self.elem]

    def to(self, *args, **kwargs):
        if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export:
            # If copy is specified as a positional arg, it's always the second one.
            if len([arg for arg in args if isinstance(arg, bool)]) <= 1:
                return super().to(*args, **{**kwargs, "copy": True})
        return super().to(*args, **kwargs)

    def cuda(self, device=None, *args, **kwargs):
        device = device or torch.cuda.current_device()
        if len(args) > 0:
            return self.to(device, *args, **kwargs)
        else:
            return self.to(device=device, **kwargs)

    char = _conversion_method_template(dtype=torch.int8)
    cpu = _conversion_method_template(device=torch.device("cpu"))
    bfloat16 = _conversion_method_template(dtype=torch.bfloat16)
    byte = _conversion_method_template(dtype=torch.uint8)
    double = _conversion_method_template(dtype=torch.float64)
    float = _conversion_method_template(dtype=torch.float32)
    bool = _conversion_method_template(dtype=torch.bool)
    half = _conversion_method_template(dtype=torch.float16)
    int = _conversion_method_template(dtype=torch.int32)
    long = _conversion_method_template(dtype=torch.int64)

    # TODO(sparse-team): fixes #133174 but can we do without the relay?
    def to_dense(self):
        return self.elem.to_dense()

    @property
    def layout(self):
        return self.elem.layout
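

# Illustrative sketch (not exercised anywhere in this module): FunctionalTensor is only
# meant to be used with an active FunctionalTensorMode (defined below), which services
# factory functions and routes mutable ops through their functional variants.
#
#   with FunctionalTensorMode():
#       ft = FunctionalTensor.to_functional(torch.ones(2))
#       ft.add_(1)                    # handled by re-dispatching to the functional variant
#       plain = ft.from_functional()  # back to a regular tensor with the updated values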


class FunctionalTensorMode(TorchDispatchMode):
    def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
        super().__init__()
        self.export = export
        self.is_on_stack = False
        self.enter_stack = []
        # Indicates to our torch_dispatch dispatching infra that
        # this is an "infra" mode with lower dispatching precedence.
        self._mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL
        self.pre_dispatch = pre_dispatch
        # This will be turned off later for pre-dispatch functionalization
        self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None  # type: ignore[attr-defined]
        # Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep
        # track of the ordering between side-effectful operations.
        self._tokens: Dict[Any, torch.Tensor] = {}

        # Filled after forward tracing.
        self._tokens_forward_output: Dict[Any, torch.Tensor] = {}

        # Functionalization runs twice in AOTAutograd: once in
        # `run_functionalized_fw_and_collect_metadata`, which collects metadata to
        # see which tensors need to be functionalized and discovers how many
        # tokens we need, and a second time in `make_fx`, which does the actual
        # tracing to replace ops with their functional variants and handle
        # side-effectful ops. In the second stage there should be no token
        # discovery. This flag distinguishes between the two stages.
        self._allow_token_discovery = _allow_token_discovery

        self._storage_to_base: weakref.WeakKeyDictionary[
            torch.storage.UntypedStorage, Optional[FunctionalTensor]
        ] = weakref.WeakKeyDictionary()

    # No-op if a FunctionalTensorMode is already in use
    def __enter__(self):
        def _get_prev_mode():
            if self._dispatch_key == torch._C.DispatchKey.PreDispatch:
                return _get_dispatch_mode_pre_dispatch(
                    torch._C._TorchDispatchModeKey.FUNCTIONAL
                )
            return torch._C._get_dispatch_mode(
                torch._C._TorchDispatchModeKey.FUNCTIONAL
            )

        if _get_prev_mode() is None:
            self.enter_stack.append(True)
            return super().__enter__()
        else:
            self.enter_stack.append(False)
            return self

    def __exit__(self, a, b, c):
        is_on_stack = self.enter_stack.pop()
        if is_on_stack:
            super().__exit__(a, b, c)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if self.export:
            # We need to make sure that we don't decompose to() as usual in export mode,
            # because it can get optimized away. Instead we always replace it with _to_copy().
            if func == torch.ops.aten.to.dtype_layout:
                kwargs.pop("copy", None)
                return self.__torch_dispatch__(
                    torch.ops.aten._to_copy.default, types, args, kwargs
                )
            if func == torch.ops.aten.to.dtype:
                schema = tuple(arg.name for arg in func._schema.arguments)
                for arg, name in zip(args[1:], schema[1:]):
                    kwargs[name] = arg
                kwargs.pop("copy", None)
                return self.__torch_dispatch__(
                    torch.ops.aten._to_copy.default, types, args[:1], kwargs
                )

        unrecognized_types = [
            t
            for t in types
            if not issubclass(t, torch._subclasses.FakeTensor)
            and t not in [torch.Tensor, FunctionalTensor]
        ]

        if unrecognized_types:
            not_implemented_log.debug(
                "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        def _can_decompose(func):
            # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832
            # Never decompose dropout in export
            if self.export and func == torch.ops.aten.dropout.default:
                return False

            # We unconditionally decompose ops that are maybe-aliasing or maybe-mutating
            if func in FunctionalTensor.maybe_aliasing_or_mutating_ops:
                return True

            # (1) we unconditionally decompose maybe-aliasing or maybe-mutating ops,
            #     because we must know statically whether an op mutates or aliases in order to functionalize it properly
            # (2) for mutating ops that have CompositeImplicit decomps, we choose to decompose them today.
            #     In theory, we could walk this back and avoid decomposing them later if we need to.
            alias_info_present = any(arg.alias_info for arg in func._schema.arguments)
            if alias_info_present or func._schema.is_mutable:
                return True

            # If we are here, it means we are seeing a functional composite op.
            # For pre-dispatch IR or export inference IR, we won't decompose them.
            if (self.export or self.pre_dispatch) and func._can_decompose():
                if func.namespace not in ["aten", "prim"]:
                    # TODO (tmanlaibaatar) check if the op is PT2 compliant
                    warnings.warn(
                        f"At pre-dispatch tracing, we assume that any custom op marked with "
                        f"CompositeImplicitAutograd and a functional schema is safe to not decompose. "
                        f"Found {func} to be one such op."
                    )
                return False

            # in normal torch.compile IR, we decompose functional composite ops
            return True

        if (
            func not in FunctionalTensor.metadata_fns
            and _can_decompose(func)
            # Not all funcs from __torch_dispatch__ are actual dispatcher ops,
            # e.g. prim.device
            and torch._C._dispatch_has_kernel(func.name())
        ):
            with self:
                r = func.decompose(*args, **kwargs)
                if r is not NotImplemented:
                    return r

        def wrap(x):
            # Only wrap our outputs in subclasses if the inner functionalization call
            # also wrapped outputs into FunctionalTensorWrappers.
            # When can this happen? e.g. `torch.div(2, 2)`
            assert not isinstance(x, FunctionalTensor)
            if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
                return FunctionalTensor(x, self)
            return x

        def unwrap(x):
            return x.elem

        from torch._higher_order_ops.auto_functionalize import (
            can_auto_functionalize,
            do_auto_functionalize,
            do_auto_functionalize_v2,
        )

        if can_auto_functionalize(
            func
        ) and not torch._C._dispatch_has_kernel_for_dispatch_key(
            func.name(), torch._C.DispatchKey.Functionalize
        ):
            # It doesn't matter what mode we use here, because
            # the implementation of do_auto_functionalize doesn't
            # interact with FunctionalTensorMode at all.
            if self.export or not inductor_config.enable_auto_functionalized_v2:
                return do_auto_functionalize(func, args, kwargs)
            else:
                return do_auto_functionalize_v2(func, args, kwargs)

        from torch._higher_order_ops.effects import handle_effects, has_effects

        if has_effects(func, args, kwargs):
            assert not torch._C._dispatch_has_kernel_for_dispatch_key(
                func.name(), torch._C.DispatchKey.Functionalize
            )
            return handle_effects(
                self._allow_token_discovery, self._tokens, func, args, kwargs
            )

        args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
            FunctionalTensor, unwrap, (args, kwargs)
        )

        # Expectation: functionalization should not **already** be enabled above our mode.
        # Why would that be bad? When we return a FunctionalTensor here, we don't want functionalization
        # to run above this mode and further wrap that output in **another** C++ FunctionalTensorWrapper.
        is_included = torch._C._dispatch_tls_is_dispatch_key_included(
            torch._C.DispatchKey.Functionalize
        )
        is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.Functionalize
        )
        assert is_excluded or not is_included
        include_to_set = (
            torch._C._dispatch_tls_local_include_set()
            | torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )
        exclude_to_set = (
            torch._C._dispatch_tls_local_exclude_set().remove(
                torch._C.DispatchKey.Functionalize
            )
            - FunctionalTensor._extra_dispatch_keys
        )

        # All we want to do here is re-use the existing C++ functionalization logic.
        # This requires swizzling our TLS dispatch keys so that the Functionalize key is active.
        with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
            try:
                # By default for python functionalization (for AOTAutograd), we reapply views.
                old_apply_views = torch._functionalize_enable_reapply_views(True)  # type: ignore[attr-defined]

                # Sometimes these functions cannot be dispatched directly to the Functionalize key,
                # because the args are not always functional tensors.
                if func in FunctionalTensor.metadata_fns:
                    outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped)
                    outs_wrapped = pytree.tree_map_only(
                        torch.Tensor, wrap, outs_unwrapped
                    )
                else:
                    # When we dispatch to the C++ functionalization kernel, we might need to jump back to the
                    # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath
                    # FunctionalTensorMode. If we called func() directly, we would need to exclude PreDispatch
                    # from the TLS in order to avoid infinite looping, but this would prevent us from coming
                    # back to PreDispatch later.
                    outs_unwrapped = func._op_dk(
                        torch._C.DispatchKey.Functionalize,
                        *args_unwrapped,
                        **kwargs_unwrapped,
                    )
                    # In export, we don't allow any mutation on the result of dropout or _to_copy
                    if self.export:
                        if func in (
                            torch.ops.aten.dropout.default,
                            torch.ops.aten._to_copy.default,
                        ):
                            torch._freeze_functional_tensor(outs_unwrapped)  # type: ignore[attr-defined]
                    outs_wrapped = pytree.tree_map_only(
                        torch.Tensor, wrap, outs_unwrapped
                    )
            finally:
                torch._disable_functionalization()
                torch._functionalize_enable_reapply_views(old_apply_views)  # type: ignore[attr-defined]

        is_included = torch._C._dispatch_tls_is_dispatch_key_included(
            torch._C.DispatchKey.Functionalize
        )
        is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded(
            torch._C.DispatchKey.Functionalize
        )
        assert is_excluded or not is_included

        if (
            # If no outputs are our functional subclass, then don't try to fix up aliasing
            not any(
                isinstance(x, FunctionalTensor)
                for x in pytree.tree_leaves(outs_wrapped)
            )
            # Since lift_fresh lifts its argument into a functional tensor, we can skip the
            # aliasing correction step. Otherwise, we would be setting the storage of a
            # lifted tensor to that of an unlifted tensor.
            # Ref: https://github.com/pytorch/pytorch/issues/111506
            or func == torch.ops.aten.lift_fresh.default
        ):
            return outs_wrapped
        # For metadata mutations, we need to manually mutate the metadata of the FunctionalTensor wrapper.
        if (
            torch.Tag.inplace_view in func.tags
            and func is not torch.ops.aten.set_.source_Tensor
        ):
            with torch.utils._mode_utils.no_dispatch():
                func(*args, **kwargs)
        # Wrapper tensor subclasses do not have correct aliasing info! Use this util to manually correct the output aliasing.
        # Inplace ops like `aten.add_()` are expected to return their inputs **directly**, instead of creating fresh tensor objects.
        # Use this util to figure out the right thing to return.
        # If none of our inputs were wrapped, then we have no FunctionalTensor outputs that we need to fix up storages for.
        return return_and_correct_aliasing(func, args, kwargs, outs_wrapped)

    @classmethod
    def is_infra_mode(cls) -> bool:
        return True
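

# Illustrative sketch: __enter__ above makes re-entry a no-op, so nested uses do not
# push a second FunctionalTensorMode onto the dispatch mode stack:
#
#   mode = FunctionalTensorMode()
#   with mode:
#       with FunctionalTensorMode():  # enter_stack records False; nothing is pushed
#           ...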


@contextlib.contextmanager
def disable_functional_mode():
    return _disable_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL)


# This is similar to torch.func.functionalize, but:
# - It uses FunctionalTensorMode, and FunctionalTensor (a python subclass).
#   One important advantage to using this mode is that it will let us
#   run functionalization underneath __torch_dispatch__,
#   which we need in AOTAutograd.
# - Doing so means that it does not automatically compose with other
#   functorch transforms, since these transforms always run above __torch_dispatch__.
# That's why this util lives here, and not in functorch.
def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMode()):
    # TODO: pull these from aot autograd
    def to_fun(t):
        if isinstance(t, torch.Tensor):
            return FunctionalTensor.to_functional(t)
        return t

    def from_fun(t):
        if not isinstance(t, FunctionalTensor):
            # quick sanity assert
            if isinstance(t, torch.Tensor):
                assert not torch._is_functional_tensor(t)
            return t
        torch._sync(t)
        return torch._from_functional_tensor(t.elem)

    def inner(*args, **kwargs):
        disable_above = torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )
        with disable_above, mode:
            func_args = pytree.tree_map_only(torch.Tensor, to_fun, args)
            func_kwargs = pytree.tree_map_only(torch.Tensor, to_fun, kwargs)
            func_outputs = func(*func_args, **func_kwargs)
            outputs = pytree.tree_map_only(FunctionalTensor, from_fun, func_outputs)

            return outputs

    return inner
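

# Illustrative sketch (assuming plain eager tensor inputs): dispatch_functionalize turns
# a function that mutates its arguments into one whose mutations are handled functionally.
#
#   def f(a):
#       a.add_(1)
#       return a * 2
#
#   f_functional = dispatch_functionalize(f)
#   out = f_functional(torch.zeros(3))  # equivalent to (torch.zeros(3) + 1) * 2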


class BaseFunctionalizeAPI(ABC):
    @abstractmethod
    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        pass

    @abstractmethod
    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Any:
        pass

    @abstractmethod
    def functionalize(self, inner_f: Callable) -> Callable:
        pass

    @abstractmethod
    def redispatch_to_next(self) -> ContextManager:
        pass

    @abstractmethod
    def replace(self, input_tensor, output_tensor) -> None:
        pass

    @abstractmethod
    def commit_update(self, tensor) -> None:
        pass

    @abstractmethod
    def sync(self, tensor) -> None:
        pass

    @abstractmethod
    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        pass


class PythonFunctionalizeAPI(BaseFunctionalizeAPI):
    def __init__(
        self, mode: Optional[FunctionalTensorMode] = None, pre_dispatch: bool = False
    ) -> None:
        super().__init__()
        self.mode = mode if mode else FunctionalTensorMode()
        self.pre_dispatch = pre_dispatch

    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        with self.mode:
            return torch.utils._pytree.tree_map_only(
                torch.Tensor, FunctionalTensor.to_functional, args
            )

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor]]
    ) -> Any:
        return torch.utils._pytree.tree_map_only(
            FunctionalTensor, FunctionalTensor.from_functional, args
        )

    def functionalize(self, inner_f: Callable) -> Callable:
        return dispatch_functionalize(inner_f, self.mode)

    def redispatch_to_next(self) -> ContextManager:
        # [NOTE] We don't do anything here because by the time
        # we exercise this path, we have already popped the
        # FunctionalTensorMode from the mode stack. Since FunctionalTensorMode
        # is now stateful, it is better to explicitly pass in the correct mode
        # directly instead of globally setting it.
        return contextlib.nullcontext()

    def replace(self, input_tensor, output_tensor) -> None:
        assert isinstance(input_tensor, FunctionalTensor)
        assert not isinstance(output_tensor, FunctionalTensor)
        input_tensor.replace_(output_tensor)

    def commit_update(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.commit_update()

    def sync(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.sync()

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        assert isinstance(tensor, FunctionalTensor)
        tensor.mark_mutation_hidden_from_autograd()


class CppFunctionalizeAPI(BaseFunctionalizeAPI):
    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional

        return _wrap_all_tensors_to_functional(args, level=0)

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        from torch._functorch.eager_transforms import (
            _unwrap_all_tensors_from_functional,
        )

        return _unwrap_all_tensors_from_functional(args, reapply_views=_reapply_views())

    def functionalize(self, inner_f: Callable) -> Callable:
        return torch.func.functionalize(inner_f)

    def redispatch_to_next(self) -> ContextManager:
        return torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
        )

    def replace(self, input_tensor, output_tensor) -> None:
        torch._functionalize_replace(input_tensor, output_tensor)

    def commit_update(self, tensor) -> None:
        torch._functionalize_commit_update(tensor)

    def sync(self, tensor) -> None:
        torch._functionalize_sync(tensor)

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(tensor)


class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI):
    def __init__(self, interpreter):
        self.interpreter = interpreter

    def wrap_tensors(self, args: Tuple[Any]) -> Tuple[Any]:
        from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional

        return _wrap_all_tensors_to_functional(args, level=self.interpreter.level())

    def unwrap_tensors(
        self, args: Union[torch.Tensor, Tuple[torch.Tensor, ...]]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        from torch._functorch.eager_transforms import (
            _unwrap_all_tensors_from_functional,
        )

        return _unwrap_all_tensors_from_functional(
            args, reapply_views=self.interpreter.functionalize_add_back_views()
        )

    def functionalize(self, inner_f: Callable) -> Callable:
        return torch.func.functionalize(
            inner_f,
            remove=(
                "mutations_and_views"
                if self.interpreter.functionalize_add_back_views()
                else "mutations"
            ),
        )

    def redispatch_to_next(self) -> ContextManager:
        return self.interpreter.lower()

    def replace(self, input_tensor, output_tensor) -> None:
        torch._functionalize_replace(input_tensor, output_tensor)

    def commit_update(self, tensor) -> None:
        torch._functionalize_commit_update(tensor)

    def sync(self, tensor) -> None:
        torch._functionalize_sync(tensor)

    def mark_mutation_hidden_from_autograd(self, tensor) -> None:
        torch._functionalize_mark_mutation_hidden_from_autograd(tensor)


def mb_unwrap_functional_tensor(tensor: torch.Tensor):
    if isinstance(tensor, FunctionalTensor):
        return torch._from_functional_tensor(tensor.elem)
    return tensor