# mypy: allow-untyped-decorators
from __future__ import annotations

import atexit
import contextlib
import dataclasses
import functools
import logging
import math
import os
import traceback
import typing
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TYPE_CHECKING,
    TypeVar,
    Union,
)
from typing_extensions import Self, TypeGuard
from weakref import ReferenceType

import torch
from torch import SymBool, SymFloat, SymInt, Tensor
from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor
from torch._prims_common import suggest_memory_format
from torch._subclasses.meta_utils import (
    assert_eq,
    assert_metadata_eq,
    is_sparse_any,
    is_sparse_compressed,
    MetaConverter,
)
from torch._utils import render_call
from torch.fx.immutable_collections import immutable_dict
from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
from torch.overrides import TorchFunctionMode
from torch.types import IntLikeType, py_sym_types
from torch.utils._backport_slots import dataclass_slots
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import (
    is_traceable_wrapper_subclass,
    TorchDispatchMode,
)
from torch.utils._pytree import PyTree, tree_map, tree_map_, TreeSpec
from torch.utils._stats import count
from torch.utils._traceback import CapturedTraceback

from ._fake_tensor_utils import _CacheKeyState, _PySymInputStub, _SymIntOutputStub


if TYPE_CHECKING:
    from types import TracebackType

    from torch._guards import Source
    from torch._ops import OpOverload
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext

log = logging.getLogger(__name__)

# TODO: Hack to unblock https://github.com/pytorch/pytorch/pull/108186
# Proper fix tracked by https://github.com/pytorch/pytorch/issues/120105
try:
    not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")
except ValueError as e:
    if "'not_implemented' not registered" in str(e):
        import logging as not_implemented_log
    else:
        raise e


class _Unassigned:
    pass


_UNASSIGNED = _Unassigned()

DimList = List

pytree = torch.utils._pytree
T = TypeVar("T")

aten = torch._ops.ops.aten

CONSTANT_NUMEL_LIMIT = 1

RECURSION_COUNT = 0


# Small helper that increments recursion count, and
# resets it when the object goes out of scope. Useful
# if you don't want to increase indentation which is
# what a context manager would do.
class IncrementRecursionCount:
    def __init__(self) -> None:
        global RECURSION_COUNT
        RECURSION_COUNT += 1

    def __del__(self) -> None:
        global RECURSION_COUNT
        RECURSION_COUNT -= 1


@dataclass
class UnsupportedFakeTensorException(RuntimeError):
    reason: str


@dataclass
class DynamicOutputShapeException(RuntimeError):
    func: OpOverload


@dataclass
class DataDependentOutputException(RuntimeError):
    func: OpOverload


@dataclass
class UnsupportedOperatorException(RuntimeError):
    func: OpOverload


def ordered_set(*items: T) -> Dict[T, Literal[True]]:
    return dict.fromkeys(items, True)


@contextlib.contextmanager
def unset_fake_temporarily() -> Generator[Optional[TorchDispatchMode], None, None]:
    old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
    try:
        yield old
    finally:
        if old is not None:
            torch._C._set_dispatch_mode(old)


def get_plain_tensors(subclass: Tensor) -> List[Tensor]:
    assert is_traceable_wrapper_subclass(subclass)
    plain_tensors: List[Tensor] = []
    todo = [subclass]
    while todo:
        curr = todo.pop()
        if not is_traceable_wrapper_subclass(curr):
            assert isinstance(curr, Tensor)
            plain_tensors.append(curr)
            continue

        inner_keys, _ = curr.__tensor_flatten__()
        for key in reversed(inner_keys):
            todo.append(getattr(curr, key))

    return plain_tensors


def is_fake(x: object) -> TypeGuard[Tensor]:
    if isinstance(x, FakeTensor):
        return True
    if is_traceable_wrapper_subclass(x):
        attrs, _ = type(x).__tensor_flatten__(x)
        flattened_tensors = [getattr(x, attr) for attr in attrs]
        all_fake = all(is_fake(x) for x in flattened_tensors)
        any_fake = any(is_fake(x) for x in flattened_tensors)
        assert all_fake == any_fake, "got mixed fake and real tensors!"
        return all_fake
    elif isinstance(x, Tensor) and torch._is_functional_tensor(x):
        reapply_views = torch._C._functionalization_reapply_views_tls()
        unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views)
        return is_fake(unwrapped)
    elif isinstance(x, Tensor) and is_functorch_wrapped_tensor(x):
        unwrapped = torch._C._functorch.get_unwrapped(x)
        return is_fake(unwrapped)
    return False


def maybe_get_fake_mode(t: object) -> Optional[FakeTensorMode]:
    if isinstance(t, FakeTensor):
        return t.fake_mode
    if is_traceable_wrapper_subclass(t):
        inner_tensor_names, _ = t.__tensor_flatten__()
        modes = [
            maybe_get_fake_mode(getattr(t, t_name)) for t_name in inner_tensor_names
        ]
        m = modes[0]
        assert all(m is x for x in modes)
        return m
    elif isinstance(t, Tensor) and torch._is_functional_tensor(t):
        reapply_views = torch._C._functionalization_reapply_views_tls()
        unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views)
        return maybe_get_fake_mode(unwrapped)
    elif isinstance(t, Tensor) and is_functorch_wrapped_tensor(t):
        unwrapped = torch._C._functorch.get_unwrapped(t)
        return maybe_get_fake_mode(unwrapped)
    return None


@functools.lru_cache(None)
def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo:
    return torch._C._SchemaInfo(func._schema)
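

# A minimal usage sketch for the helpers above (illustrative only, not a
# normative API example): inside code running under FakeTensorMode,
# `unset_fake_temporarily()` temporarily disables the fake mode so real-tensor
# computation can happen, and `is_fake(t)` reports whether a tensor is fake:
#
#     with unset_fake_temporarily():
#         real = torch.zeros(3)  # produced without the fake mode intercepting
#     assert not is_fake(real)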


# many of the decompositions registered to torch/_prims do not at the moment model
# aliasing or strides, so as an incremental step, just enable the decompositions in
# torch/_decomp/decompositions.py.
# decomps are used for aot autograd tracing so we would like to unify on their
# implementation and add additional testing to them
@functools.lru_cache(None)
def torch_decomp_decompositions(func: OpOverload) -> bool:
    from torch._decomp import decomposition_table

    decompositions = torch._decomp.decompositions
    # Note that the function in the decomposition table might be
    # different from the one in the module because of the difference
    # in out handling in aten API and torch public API
    return decomposition_table[func].__module__.startswith(
        "torch._decomp"
    ) and decomposition_table[func].__name__ in dir(decompositions)


def tree_flatten_only(ty: Type[T], tree: PyTree) -> List[T]:
    flat_vals = pytree.tree_leaves(tree)
    return [elem for elem in flat_vals if isinstance(elem, ty)]


def _is_plain_tensor(t: object) -> bool:
    return (
        type(t) is Tensor
        and t.layout == torch.strided
        and not (
            t.is_sparse
            or t.is_nested
            or is_functorch_wrapped_tensor(t)
            or is_legacy_batchedtensor(t)
            or torch._is_functional_tensor(t)
        )
    )


# Similar to `MetaConverter`, this is a class for converting
# multiple tensors into fake tensors which share the same view/storage
# structure. Like `MetaConverter`, it uses `WeakIdRef` to
# hold a weak reference for all memoized tensors.
class FakeTensorConverter:
    @property
    def tensor_memo(
        self,
    ) -> weakref.WeakValueDictionary:
        # not valid until py3.10
        # weakref.WeakValueDictionary["torch._subclasses.meta_utils.MetaTensorId", Optional["FakeTensor"]]
        return self.meta_converter.tensor_memo

    meta_converter: MetaConverter
    constant_storage_mapping: Dict[StorageWeakRef, List[ReferenceType]]
    export: bool

    def __init__(self, *, copy_data: bool = False, export: bool = False) -> None:
        self.meta_converter = MetaConverter(copy_data=copy_data)
        self.export = export

        # map from storage to corresponding constant tensors
        self.constant_storage_mapping = {}

    def add_constant_storage_mapping(self, fake_tensor: FakeTensor) -> None:
        # when you have a constant, aliased tensor:
        # const_tensor.add_(torch.rand([1]))
        # all aliases of it must become no longer const
        assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None
        weak_st = StorageWeakRef(fake_tensor.constant._typed_storage())

        # we need a map from a weak storage to all of its corresponding
        # constant tensors. python doesn't have the weak value equivalent
        # of defaultdict(list), so we are using a WeakValueDictionary as one
        if weak_st not in self.constant_storage_mapping:
            self.constant_storage_mapping[weak_st] = []
        self.constant_storage_mapping[weak_st].append(weakref.ref(fake_tensor))

    def invalidate_constant_aliases(self, tensor: Tensor) -> None:
        assert not isinstance(tensor, FakeTensor)

        weak_st = StorageWeakRef(tensor._typed_storage())
        if weak_st not in self.constant_storage_mapping:
            return

        for weak_tensor_ref in self.constant_storage_mapping[weak_st]:
            ten = weak_tensor_ref()
            if ten is not None:
                ten._fix_weakref()
                ten.constant = None

        del self.constant_storage_mapping[weak_st]

    def _get_memo(self, t: Tensor) -> Optional[FakeTensor]:
        tid = self.meta_converter.describer.lookup_tensor.get(t)
        if tid is None:
            return None
        return self.tensor_memo.get(tid)

    def set_tensor_memo(self, t: Tensor, v: FakeTensor) -> None:
        tid = self.meta_converter.describer.get_tensor_id(t)
        self.meta_converter.tensor_memo[tid] = v

    # You can have a real tensor that you need to convert into a fake tensor.
    # If you have a meta tensor already, call from_meta_and_device.
    #
    # You're allowed to pass a meta tensor to be turned into a fake
    # tensor; although an odd thing to do, this can occur if you're doing
    # cross ref testing and the inner test is already operating on meta tensors.
    def from_real_tensor(
        self,
        fake_mode: FakeTensorMode,
        t: Tensor,
        make_constant: bool = False,
        shape_env: Optional[ShapeEnv] = None,
        *,
        source: Optional[Source] = None,
        symbolic_context: Optional[SymbolicContext] = None,
        trace: bool = True,
    ) -> FakeTensor:
        # see note [Tensor Fakification and Symbol Caching]
        if not symbolic_context and not source and shape_env:
            if tracing_context := torch._guards.TracingContext.try_get():
                if t in tracing_context.tensor_to_context:
                    symbolic_context = tracing_context.tensor_to_context[t]
                    from torch.fx.experimental.symbolic_shapes import (
                        StatefulSymbolicContext,
                    )

                    assert isinstance(symbolic_context, StatefulSymbolicContext)
                    source = symbolic_context.tensor_source

        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
        existing_device = t.device
        # not yet supported in metatensors
        if t.is_quantized:
            raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
        if type(t) is torch.nn.Parameter:
            assert not make_constant

        def mk_fake_tensor(make_meta_t: Callable[[], object]) -> FakeTensor:
            # NB: don't use in_kernel_invocation_manager, to
            # ensure FakeTensor can internally do constant computation
            # as necessary. Invocation manager is "more correct" as
            # it works for more operators in make_meta_t, but
            # invariant is that make_meta_t only calls factories
            # for which it is not strictly necessary to use the
            # invocation manager (I think!)
            with no_dispatch():
                return FakeTensor(
                    fake_mode,
                    make_meta_t(),
                    existing_device,
                    # TODO: callback might be used in recursive contexts, in
                    # which case using t is wrong! BUG!
                    constant=t if make_constant else None,
                )

        out = self.meta_converter(
            t,
            shape_env=shape_env,
            callback=mk_fake_tensor,
            source=source,
            symbolic_context=symbolic_context,
            trace=trace,
        )
        if out is NotImplemented:
            raise UnsupportedFakeTensorException("meta converter nyi")

        from torch._dynamo.source import RandomValueSource

        value = None
        if (
            not self.export
            and _is_plain_tensor(t)  # mostly, we want to know if item() works
            and t.dim() == 0
            and t.device.type == "cpu"
            # All integer types are fair game, because signed overflow is UB
            # (and even int64 can overflow, since integers in Python are
            # arbitrary precision). But only float64 is OK for float, because
            # switching between float32 and float64 changes semantics in an
            # observable way without hitting UB.
            and t.dtype
            in [torch.int64, torch.int32, torch.int16, torch.int8, torch.float64]
            and source is not None
            # Impede setting up item() on things coming from random. These
            # are not "real" item() calls, instead UnspecializedPythonVariable
            # is unsafely pretending an int is a tensor, which can sometimes
            # implicitly cause an item call. The problem is this is pretty
            # unsound: there's no reason substituting an int with a Tensor is
            # going to give the same results. Today, you mostly get around
            # this by typically not having capture_scalar_outputs on and graph
            # breaking when someone tries to use the unspec variable in an
            # int-y context. But allowing it through here would break that.
            # So don't.
            #
            # Once random values are setup to be represented as
            # SymNodeVariable, this condition can be removed. To check if
            # you've done it right, this is a good test:
            #
            #     PYTORCH_TEST_WITH_DYNAMO=1 python test/test_reductions.py -k
            #     TestReductionsCPU.test_dim_reduction_fns_fn_name_amax_cpu_bfloat16
            and not isinstance(source, RandomValueSource)
            # In Dynamo, shape_env is never none (even with static shapes).
            # However, FakeTensorMode can be used by hand and in some cases
            # ShapeEnv is not allocated.
            and shape_env is not None
        ):
            from torch._dynamo.source import CallMethodItemSource, FloatTensorSource
            from torch.fx.experimental.symbolic_shapes import DimDynamic

            with no_dispatch():
                value = t.item()
            if not math.isnan(value):
                # Peephole strip out unnecessary torch.as_tensor(x).item()
                if isinstance(source, FloatTensorSource):
                    item_source = source.base
                else:
                    item_source = CallMethodItemSource(source)
                symbol = shape_env.create_unspecified_symbol(
                    value,
                    source=item_source,
                    dynamic_dim=DimDynamic.DYNAMIC,
                )
                # NB: reusing item_memo here ensures that we invalidate on
                # mutation
                if t.dtype == torch.int64:
                    out.item_memo = shape_env.create_symintnode(
                        symbol,
                        hint=value,
                        source=item_source,
                    )
                elif t.dtype == torch.float64:
                    out.item_memo = shape_env.create_symfloatnode(
                        symbol,
                        hint=value,
                        source=item_source,
                    )
        if make_constant:
            self.add_constant_storage_mapping(out)
        # NB: meta_converter sets the memo
        return out
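
    # A minimal usage sketch (illustrative only): a converter instance is
    # normally reached through a FakeTensorMode, and repeated conversions of
    # the same real tensor return the same memoized FakeTensor:
    #
    #     mode = FakeTensorMode()
    #     fake = mode.fake_tensor_converter.from_real_tensor(mode, torch.ones(2, 2))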

    # If you specify the device, it MUST be a meta tensor.
    def from_meta_and_device(
        self, fake_mode: FakeTensorMode, t: Tensor, device: torch.device
    ) -> FakeTensor:
        assert (
            t.device.type == "meta"
        ), f"tensor's device must be `meta`, got {t.device.type} instead"
        # This is a bit abusive (this is not the "real" tensor) but whatever,
        # the meta tensor should be fresh so there's no way to get it wrong
        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
        out = FakeTensor(fake_mode, t, device)
        self.set_tensor_memo(t, out)
        return out


@functools.lru_cache(None)
def init_gpu_context() -> None:
    # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first
    if torch.cuda.is_available():
        (
            torch.empty(1, device="cuda")
            if torch.version.hip is None
            else torch.zeros(1, device="cuda")
        )

    if torch.xpu.is_available():
        (torch.empty(1, device="xpu"))


@contextlib.contextmanager
def in_kernel_invocation_manager(
    fake_mode: FakeTensorMode,
) -> Generator[None, None, None]:
    # See: note [Fake Tensor Dispatch Keys]
    prev_in_kernel = fake_mode.in_kernel_invocation
    meta_in_tls = torch._C._meta_in_tls_dispatch_include()
    assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"

    with torch._C._DisableTorchDispatch():
        fake_mode.in_kernel_invocation = True
        # Unfortunately _set_meta_in_tls_dispatch_include(False) can leave
        # `Dense` turned on (because it's implied by `Meta`)
        with torch._C._PreserveDispatchKeyGuard():
            torch._C._set_meta_in_tls_dispatch_include(True)
            try:
                yield
            finally:
                fake_mode.in_kernel_invocation = prev_in_kernel
                # torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)


# Return if the function allows Python numbers to bind to Tensors
def should_allow_numbers_as_tensors(func: OpOverload) -> bool:
    return torch._C._should_allow_numbers_as_tensors(
        func.name().split("::")[-1].split(".")[0]
    )


class FakeTensorConfig:
    debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", "0") == "1"
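
# Note (descriptive, not normative): when FakeTensorConfig.debug is enabled via
# TORCH_FAKE_TENSOR_DEBUG=1, FakeTensor.__new__ below records a CapturedTraceback
# in `_debug_trace`, which helps track down where a given fake tensor was created.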


# This memoizes the unbacked SymInt representing quantities like the number
# of nonzero elements in this tensor. There is one instance of the descriptor
# per particular quantity to memoize.
#
# Memoization is helpful if you do something like x[mask] and y[mask];
# mask.nonzero() gets repeatedly called and should give a consistent unbacked
# SymInt. It needs to be invalidated in the same way constant is.
#
# Making this a descriptor may seem overly fancy, but actually it's the most
# convenient way to make sure we have access to FakeTensor during access,
# which is required for testing version counter and epoch validity
class SymIntMemoDescriptor:
    _name: str

    # By default, SymInts in this memo are invalidated across versions/epochs.
    # nested_ints however are preserved across epochs and across versions.
    # Preserving across versions is okay for nested int since the association
    # of a nested int is agnostic to the underlying data and nested ints are not
    # shared across multiple distinct tensors.
    _is_nested_int: bool

    def __init__(self, *, is_nested_int: bool = False) -> None:
        self._is_nested_int = is_nested_int

    def __set_name__(self, owner: str, name: str) -> None:
        self._name = name

    def _memo(self, obj: FakeTensor) -> str:
        return f"_{self._name}"

    def _memo_vc(self, obj: FakeTensor) -> str:
        return f"_{self._name}_vc"

    # When we retrace, we need to invalidate all the memos so that we can
    # accurately identify the first time unbacked SymInts are allocated.
    # This is only relevant for inputs; for intermediates, they will get fresh
    # fake tensors so you won't have a memo anyway
    def _memo_epoch(self, obj: FakeTensor) -> str:
        return f"_{self._name}_epoch"

    def __get__(
        self, obj: FakeTensor, objtype: Optional[Type[FakeTensor]] = None
    ) -> Optional[torch.SymInt]:
        if (r := getattr(obj, self._memo(obj))) is None:
            return None
        # Version counter based tracking isn't 100% sound but it's close
        # enough
        if (
            not self._is_nested_int and getattr(obj, self._memo_vc(obj)) != obj._version
        ) or (
            not self._is_nested_int
            and getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch
        ):
            setattr(obj, self._memo(obj), None)
            return None
        return r

    def __set__(self, obj: FakeTensor, value: Optional[torch.SymInt]) -> None:
        if value is None:
            setattr(obj, self._memo(obj), None)
            setattr(obj, self._memo_vc(obj), None)
            setattr(obj, self._memo_epoch(obj), None)
        elif not obj.is_inference() or self._is_nested_int:
            setattr(obj, self._memo(obj), value)
            if not self._is_nested_int:
                setattr(obj, self._memo_vc(obj), obj._version)
            setattr(obj, self._memo_epoch(obj), obj.fake_mode.epoch)


class FakeTensor(Tensor):
    """
    Meta tensors give you the ability to run PyTorch code without having to
    actually do computation through tensors allocated on a `meta` device.
    Because the device is `meta`, meta tensors do not model device propagation.
    FakeTensor extends MetaTensors to also carry an additional `fake_device`
    which tracks devices that would have been used.
    """

    fake_device: torch.device
    fake_mode: FakeTensorMode
    constant: Optional[Tensor]
    real_tensor: Optional[Tensor]

    # TODO: Generalize this as needed, e.g., into a trie of memos, if
    # you do something like x[0].item() (x[0] is fresh each time, so
    # memo mechanism here won't work)
    nonzero_memo = SymIntMemoDescriptor()
    item_memo = SymIntMemoDescriptor()
    unique_memo = SymIntMemoDescriptor()

    # We expect nested_int_memo to be None when an offsets is a graph
    # intermediate, or an input that has never been associated with a
    # nested int.
    nested_int_memo = SymIntMemoDescriptor(is_nested_int=True)

    # Indicates to our torch_dispatch dispatching infra that
    # this is an "infra" mode with lower dispatching precedence.
    _mode_key = torch._C._TorchDispatchModeKey.FAKE

    @property
    def device(self) -> torch.device:
        if self.fake_mode.in_kernel_invocation:
            return torch.device("meta")
        else:
            return self.fake_device

    @device.setter
    def device(self, _: torch.device) -> None:
        raise NotImplementedError

    # Note: [Fake Tensor Dispatch Keys]
    # In order to model the behavior of device-specific autocast
    # and autograd logic, we update the dispatch keys of FakeTensors
    # to reflect their fake device. This includes the BackendComponent
    # (DispatchKey::Meta -> DispatchKey::CUDA), and also the BackendComponent
    # related Autocast and Autograd keys. __torch_dispatch__ sits below
    # Autocast and Autograd, and is only invoked when we are at the
    # kernel for the BackendComponent. Then, we add Meta to the
    # thread-local dispatch include set to hit the meta kernel
    # instead of the kernel of the BackendComponent for the fake device.
    # The `device_for_backend_keys` does that below
    # NOTE: this probably will not do the right thing for backends
    # that have dispatch keys which are higher than the "meta" key:
    # https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L189

    # We don't support named tensors; graph break
    @property
    def names(self) -> List[str]:
        raise UnsupportedFakeTensorException(
            "torch.compile doesn't support named tensors"
        )

    @names.setter
    def names(self, _: List[str]) -> None:
        raise NotImplementedError

    @staticmethod
    def __new__(
        cls,
        fake_mode: FakeTensorMode,
        elem: Tensor,
        device: torch.device,
        constant: Optional[Tensor] = None,
        real_tensor: Optional[Tensor] = None,
    ) -> Self:
        self = Tensor._make_subclass(
            cls,
            elem,
            elem.requires_grad,
            dispatch_device=True,
            device_for_backend_keys=device,
        )
        if not fake_mode._allow_unsafe_data_ptr_access:
            torch._C._set_throw_on_mutable_data_ptr(self)
        else:
            torch._C._set_warn_deprecated_on_mutable_data_ptr(self)

        assert elem.device.type == "meta", elem.device.type
        device = device if isinstance(device, torch.device) else torch.device(device)
        # NB: it is fine, if a little confusing, for device to be meta
        # (we are faking a meta tensor in that case). However, it often
        # indicates some sort of confusion (e.g., you accidentally passed
        # in a meta tensor when you should have passed in the real tensor).
        # So by default we disallow meta, and if you are working in a situation
        # where it is helpful (e.g., crossref testing) you can turn it back
        # on
        if not fake_mode.allow_meta:
            assert device.type != "meta"
        # normalize device.
        if device.type in ["cuda", "xpu"]:
            init_gpu_context()

        if (
            device.type
            in ["cuda", "hpu", "xpu", torch._C._get_privateuse1_backend_name()]
            and device.index is None
        ):
            if getattr(torch, device.type).is_initialized():
                device = torch.device(
                    f"{device.type}:{getattr(torch, device.type).current_device()}"
                )
            else:
                device = torch.device(f"{device.type}:0")
        self.fake_device = device
        self.fake_mode = fake_mode
        self.constant = constant
        assert not isinstance(real_tensor, FakeTensor)
        self.real_tensor = real_tensor
        self.nonzero_memo = None
        self.item_memo = None
        self.unique_memo = None
        self.nested_int_memo = None

        if FakeTensorConfig.debug:
            self._debug_trace = CapturedTraceback.extract()  # type: ignore[attr-defined]
        return self

    # In some circumstances, a conventional Tensor constructor
    # will get rewritten to call into FakeTensor. We must provide an
    # __init__ method that can accept the Python interpreter's initialization
    # in such a situation; we must also be able to handle direct fake
    # tensor construction via FakeTensor().
    #
    # In particular, the __init__ call will look funny in the following case:
    #
    #   with FakeTensorMode():
    #       x = Tensor([1, 2, 3])
    #
    # this desugars into:
    #
    #   with FakeTensorMode():
    #       x = Tensor.__new__([1, 2, 3])
    #       # NB: x is a fake tensor, because of the mode!
    #       x.__init__([1, 2, 3])  # not the normal fake tensor args!
    #
    def __init__(self, *args: object, **kwargs: object) -> None:
        super().__init__()

    @staticmethod
    def from_tensor(t: Tensor, fake_mode: FakeTensorMode) -> FakeTensor:
        return fake_mode.from_tensor(t)

    @classmethod
    @count
    def __torch_dispatch__(
        cls,
        func: OpOverload,
        types: Sequence[Type],
        args: Sequence[object] = (),
        kwargs: Mapping[str, object] = immutable_dict(),
    ) -> object:
        # need to handle here to avoid infinite recursion
        # see [in_kernel_invocation]
        if func == torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                return torch.device("meta")
            else:
                return args[0].fake_device

        # this handler must be done inside FakeTensor subclass, not mode, because
        # we can end up dispatching here when we have a fake tensor with
        # symbolic sizes running under in_kernel_invocation_manager.
        # The subclass is asked to handle this query because size (not
        # sym_size) was called, but we are unable to serve it directly because
        # there are symbolic sizes in the class. The use of
        # in_kernel_invocation_manager means it's incorrect to activate a
        # mode to actually handle this (this caused
        # https://github.com/pytorch/pytorch/issues/122772).
        if handler := _DISPATCH_META_HANDLERS.get(func):
            return handler(args)

        # Because fake mode can return NotImplemented (if it sees a subclass
        # it doesn't know how to deal with), this test here is important
        # because the next dispatch after a fake mode will attempt to use
        # subclasses of tensors to dispatch, and any FakeTensor arguments
        # will be considered eligible.
        unrecognized_types = [
            t for t in types if not issubclass(t, FakeTensor) and t is not Tensor
        ]
        if unrecognized_types:
            not_implemented_log.debug(
                "FakeTensor unrecognized subclass(es): %s", unrecognized_types
            )
            return NotImplemented

        fake_mode = None
        for arg in pytree.arg_tree_leaves(*args, **kwargs):
            if isinstance(arg, FakeTensor):
                fake_mode = arg.fake_mode
                break

        assert fake_mode is not None

        # If the fake mode is already active, don't try to reapply it!
        # NotImplemented is the right thing to return here, because the
        # typical situation this can occur is if ProxyTensorMode returned a
        # NotImplemented because of a not implemented subclass; we may have
        # unluckily attempted to hit FakeTensor's dispatch first,
        # NotImplemented lets us keep chaining until we find the actual
        # subclass
        maybe_cur_fake_mode = torch._C._get_dispatch_mode(
            torch._C._TorchDispatchModeKey.FAKE
        )
        if maybe_cur_fake_mode:
            not_implemented_log.debug(
                "FakeTensor mode already active: %s in %s",
                fake_mode,
                maybe_cur_fake_mode,
            )
            return NotImplemented

        assert not fake_mode.in_kernel_invocation

        with fake_mode:
            return func(*args, **kwargs)

    @staticmethod
    def _find_common_device(
        func: OpOverload, flat_args: Sequence[object]
    ) -> Tuple[torch.device, bool]:
        # Returns: (common_device, has_scalar_only_inputs)

        # cpu - zero-dim tensors can be called in cuda kernels,
        # so overwrite the common_device if the only existing
        # device comes from a cpu zero-dim tensor
        common_device = None
        has_scalar_only_inputs = False
        is_cpu_zero_dim = None

        def cpu_zero_dim(t: Tensor) -> bool:
            return t.device.type == "cpu" and t.dim() == 0

        def merge_devices(t: object) -> None:
            nonlocal common_device
            nonlocal is_cpu_zero_dim
            if not isinstance(t, FakeTensor):
                return

            if common_device is None:
                common_device = t.device
                is_cpu_zero_dim = cpu_zero_dim(t)
                return

            t_is_cpu_zero_dim = cpu_zero_dim(t)
            if t.device == common_device:
                if is_cpu_zero_dim:
                    is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # mismatching devices!
            # if current tensor is cpu 0 dim, defer to existing device
            if t_is_cpu_zero_dim:
                return

            # current device is from cpu 0 dim tensor, overwrite
            if is_cpu_zero_dim:
                common_device = t.device
                is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # mismatching devices of non-zero dim tensors, throw
            # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
            raise RuntimeError(
                f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
            )

        for arg in flat_args:
            merge_devices(arg)

        # some functions that allow Python numbers to bind to Tensors
        # if we have failed to find a device, and we're running one of these operators,
        # we must have scalar only inputs
        if should_allow_numbers_as_tensors(func) and common_device is None:
            # ops with scalar only inputs always have result on cpu
            has_scalar_only_inputs = True
            common_device = torch.device("cpu")

        assert common_device is not None, f"Could not find common device for {func}"

        return common_device, has_scalar_only_inputs

    def get_nested_int(
        self,
        *,
        coeff: Union[int, torch.SymInt] = 1,
    ) -> torch.SymInt:
        if self.nested_int_memo is None:
            self.nested_int_memo = self.fake_mode.create_symbolic_nested_int(
                nt_tensor_id=None
            )
        return self.nested_int_memo * coeff

    # Similar to FunctionalTensor.tolist
    def tolist(self) -> Any:
        if self.dim() == 0:
            return self.item()
        elif self.dim() == 1:
            return [elem.item() for elem in self]
        else:
            return [elem.tolist() for elem in self]


_MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"]


@dataclass_slots
@dataclass
class TensorMetadata:
    """
    The Tensor metadata relevant to hashing FakeTensors when caching.
    """

    dtype: torch.dtype
    shape: Tuple[_MetadataIntLike, ...]
    stride: Tuple[_MetadataIntLike, ...]
    device: torch.device
    layout: torch.layout
    memory_format: Optional[torch.memory_format]
    storage_offset: _MetadataIntLike
    storage_bytes: Optional[_MetadataIntLike]
    requires_grad: bool
    is_quantized: bool
    is_conj: bool
    is_neg: bool
    is_inference: bool
    is_sparse: bool  # read: is sparse COO
    is_coalesced: Optional[bool]
    dense_dim: Optional[int]
    sparse_dim: Optional[int]

    def _flatten_into(
        self,
        result: List[object],
        mode: FakeTensorMode,
        state: _CacheKeyState,
    ) -> None:
        # Flatten the TensorMetadata out into `result`. Make sure to call
        # state.convert_sym_int() on any SymInts.
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if isinstance(value, (tuple, list, torch.Size)):
                # This will recursively flatten the iterable, calling
                # convert_sym_int() as necessary.
                mode._prep_args_for_hash(result, value, state)
            elif isinstance(value, SymInt):
                state.convert_sym_int(result, value)
            else:
                result.append(value)


def extract_tensor_metadata(t: Tensor) -> TensorMetadata:
    """
    Extract the TensorMetadata of a tensor.
    """
    memory_format: Optional[torch.memory_format] = suggest_memory_format(t)
    # Don't call is_contiguous() on a Tensor which has symbolic sizes or things
    # will go badly (guards will be messed up?)
    if (
        t._has_symbolic_sizes_strides
        or is_sparse_any(t)
        or not t.is_contiguous(memory_format=memory_format)
    ):
        memory_format = None

    storage_offset = t.storage_offset()

    return TensorMetadata(
        t.dtype,
        t.shape,
        t.stride() if t.layout == torch.strided else (),
        t.device,
        t.layout,
        memory_format,
        storage_offset,
        # Only set storage_bytes for tensors that have storage (not sparse)
        t.untyped_storage().nbytes() if not is_sparse_any(t) else None,
        t.requires_grad,
        t.is_quantized,
        t.is_conj(),
        t.is_neg(),
        t.is_inference(),
        t.is_sparse,
        t.is_coalesced() if t.is_sparse else None,
        t.dense_dim() if is_sparse_any(t) else None,
        t.sparse_dim() if is_sparse_any(t) else None,
    )


@dataclass_slots
@dataclass
class _DispatchCacheKey:
    """
    Key for the FakeTensor dispatch cache.
    """

    key: Tuple[object, ...]
    hashvalue: int

    def __init__(self, tup: Tuple[object, ...]) -> None:
        self.key = tup
        self.hashvalue = hash(tup)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, _DispatchCacheKey) and self.key == other.key

    def __hash__(self) -> int:
        return self.hashvalue

    def strip_shape_env(self) -> None:
        # We need to strip the ShapeEnv from any values before we store in the
        # cache so the cache doesn't keep our ShapeEnvs alive.
        for v in self.key:
            if isinstance(v, _PySymInputStub):
                v.strip_shape_env()


@dataclass_slots
@dataclass(frozen=True)
class _DispatchCacheEntry:
    """
    Entry type for the FakeTensor dispatch cache. Accounts for two possibilities:
    1) The op is inplace, and a hit means we need to alias the argument at a
       given index.
    2) We need to synthesize a new FakeTensor given tensor metadata. For view
       ops, we further capture the index of the arg to alias.
    """

    inplace_idx: Optional[int]
    metadata: Optional[TensorMetadata]
    view_idx: Optional[int]


@dataclass_slots
@dataclass(frozen=True)
class _BypassDispatchCache(Exception):
    """
    Signals cases that should skip FakeTensor caching.
    """

    reason: str


@dataclass_slots
@dataclass(frozen=True)
class DispatchCacheInfo:
    """
    Information about the state of the FakeTensor dispatch cache.
    """

    hits: int
    misses: int
    bypasses: Dict[str, int]
    size: int
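

# For reference (illustrative only): DispatchCacheInfo is what
# FakeTensorMode.cache_info() below returns, e.g.
#
#     info = FakeTensorMode.cache_info()
#     print(info.hits, info.misses, info.bypasses, info.size)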


# We keep one instantiation of `fake_tensor_converter` active
# for the duration of `with FakeTensorMode()`.
# This allows accurate storage aliasing across invocation of
# different operators. While this will keep all freshly allocated
# tensors alive during `FakeTensorMode`, there will be no
# new allocations of Tensors which have non-meta storage so
# memory should not significantly increase.


class FakeTensorMode(TorchDispatchMode):
    cache: Dict[_DispatchCacheKey, _DispatchCacheEntry] = {}
    cache_hits: int = 0
    cache_misses: int = 0
    cache_bypasses: Dict[str, int] = defaultdict(int)
    # Every time you retrace using the same fake tensor mode, you should
    # advance the epoch so we don't reuse unbacked memos
    epoch: int = 0
    in_kernel_invocation: bool = False
    static_shapes: bool
    shape_env: Optional[ShapeEnv]
    _stack: Optional[str]
    allow_meta: bool

    # NestedTensor uses a tensor_id_counter to uniquely identify offsets.
    # This counter is incremented when an offsets is used to create an NJT
    # for the first time. To avoid mutating eager state if we construct NJT
    # during tracing, we maintain a separate counter on the FakeTensorMode.
    # The initial count is set to the current eager tensor_id_counter value
    # upon initialization, and every time you retrace using the same fake tensor
    # mode, you should reset the counter to the initial count.
    nt_tensor_id_counter: int = -1
    nt_tensor_id_initial_count: int = -1

    def __init__(
        self,
        *,
        allow_fallback_kernels: bool = True,
        allow_non_fake_inputs: bool = False,
        shape_env: Optional[ShapeEnv] = None,
        static_shapes: Optional[bool] = None,
        # TODO: This is a temporary measure, see
        # https://github.com/pytorch/pytorch/pull/126245#discussion_r1604185748
        # We're currently solely using this to impede population of
        # item_memo for 0d scalar tensor inputs when exporting, because this
        # causes things that used to be deferred runtime asserts to turn into
        # guards, and then the guards are just lost. We can potentially fix
        # this by ensuring guards also get put in the graph, but this is
        # pending a rework of how deferred runtime asserts are handled in
        # export. Once that's done, we can remove this.
        export: bool = False,
    ) -> None:
        log.debug("create_mode 0x%x", id(self))
        super().__init__()
        self.allow_fallback_kernels = allow_fallback_kernels

        import torch._dynamo.config
        import torch._functorch.config

        self.propagate_real_tensors = (
            torch._functorch.config.fake_tensor_propagate_real_tensors
        )
        self.fake_tensor_converter = FakeTensorConverter(
            copy_data=self.propagate_real_tensors,
            export=export,
        )

        if static_shapes is not None:
            self.static_shapes = static_shapes
        else:
            self.static_shapes = shape_env is None

        # This is temporarily patched to True in Dynamo to grandfather in some
        # places where we unconditionally allow scalar outputs, TO BE REMOVED
        self.allow_scalar_outputs = False

        self._allow_unsafe_data_ptr_access = (
            torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access
        )
        self.allow_meta = torch._functorch.config.fake_tensor_allow_meta
        self.cache_enabled = (
            torch._dynamo.config.fake_tensor_cache_enabled
            and not self.propagate_real_tensors
        )
        self.cache_crosscheck_enabled = (
            torch._dynamo.config.fake_tensor_cache_crosscheck_enabled
        )

        # A flag that controls whether we want to invoke ops on a mix of
        # real weights/global variables and fake inputs
        self.allow_non_fake_inputs = allow_non_fake_inputs

        # [in_kernel_invocation]
        # when FakeTensor is invoked in user code, .device should return
        # the fake_device of the tensor so that code such as `if x.is_cuda`
        # or torch.zeros([10, 10], device=x.device) continues to execute as if
        # the FakeTensor were real. However, within kernel execution, we return
        # the `Meta` device because all computation within the kernels should
        # behave as if the Tensors are on meta devices. Kernels should allocate
        # new tensors on meta devices, and checks like `is_meta` should return true.
        # within python refs, we always return the real device by defining
        # the device property
        self.in_kernel_invocation = False

        # True if we entered and actually enabled fake tensor mode,
        # false if it was a no-op. Not thread safe but neither is
        # in_kernel_invocation
        # If another fake mode was already active when we enter, we also stash it here.
        # That way when we exit, we know to re-enable the previous fake mode.
        self.enter_stack: List[
            Tuple[bool, Optional[TorchDispatchMode], Optional[bool]]
        ] = []

        self.shape_env = shape_env

        self._stack_trace = traceback.extract_stack()
        self._stack = None

        # Indicates to our torch_dispatch dispatching infra that
        # this is an "infra" mode with lower dispatching precedence.
        self._mode_key = torch._C._TorchDispatchModeKey.FAKE

        import torch.nested._internal.nested_tensor

        self.nt_tensor_id_initial_count = (
            torch.nested._internal.nested_tensor._tensor_id_counter
        )
        self.nt_tensor_id_counter = self.nt_tensor_id_initial_count

    def reset_nt_tensor_id_counter(self) -> None:
        self.nt_tensor_id_counter = self.nt_tensor_id_initial_count
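
    # A minimal usage sketch (illustrative only): with the mode active, factory
    # functions and ops produce FakeTensors, so shape/dtype/device propagation
    # runs without allocating real storage:
    #
    #     with FakeTensorMode() as mode:
    #         x = torch.empty(128, 128)
    #         y = x @ x  # y is a FakeTensor with shape (128, 128)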

    # Typically, there is only one fake tensor mode and you test for it by
    # doing an isinstance test. However, in some situations, there might be
    # TWO fake tensor modes. The canonical example of this is exporting
    # a fake model: there is an outer fake mode created by the user, and
    # an inner fake mode created by Dynamo. The two phase process is required
    # because the outer fake mode typically won't have a ShapeEnv, even if
    # the user is interested in exporting with dynamic shapes (so the inner
    # fake mode will actually have a ShapeEnv and swap in symbolic sizes.)
    #
    # In this case, it's insufficient to test only one FakeTensor: you need
    # to distinguish between our fake tensor and other fake tensors. That's
    # what this function does.
    def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]:
        return isinstance(t, FakeTensor) and t.fake_mode is self

    # If we should avoid device init. This changes the behavior of various APIs:
    # - We avoid constant-prop on Tensors with ops that move them to another device
    # - We change the torch.tensor ctor contract to never materialize
    #   tensors on device
    #   (see NOTE: [torch.tensor, lift_fresh, and device movement])
    @property
    def avoid_device_init(self) -> bool:
        if torch.xpu._is_compiled():
            assert not torch.cuda._is_compiled()
            return not torch.xpu.is_available()

        return not torch.cuda.is_available()

    @property
    def stack(self) -> str:
        if self._stack is None:
            self._stack = "".join(traceback.format_list(self._stack_trace))
        return self._stack

    @count
    def __torch_dispatch__(
        self,
        func: OpOverload,
        types: Sequence[Type],
        args: Sequence[object] = (),
        kwargs: Mapping[str, object] = immutable_dict(),
    ) -> object:
        # FakeTensorMode should not be set when we're inside of it.
        assert (
            torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None
        ), func
        try:
            return self.dispatch(func, types, args, kwargs)
        except TypeError:
            log.exception("fake tensor raised TypeError")
            raise

    # No-op if FakeTensorMode is already in use
    def __enter__(self) -> Self:
        import torch.nested._internal.nested_tensor

        prev_only_lift_cpu_tensors = None
        if self.avoid_device_init:
            # See NOTE: [torch.tensor, lift_fresh, and device movement]
            prev_only_lift_cpu_tensors = torch._C._only_lift_cpu_tensors()
            torch._C._set_only_lift_cpu_tensors(True)
        maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key)
        if self is not maybe_prev_fake_mode:
            self.enter_stack.append(
                (True, maybe_prev_fake_mode, prev_only_lift_cpu_tensors)
            )
            return super().__enter__()
        else:
            # no-op (still need to re-set the fake mode though since we unset it)
            torch._C._set_dispatch_mode(self)
            self.enter_stack.append((False, None, prev_only_lift_cpu_tensors))
        return self

    def __exit__(
        self,
        a: Optional[Type[BaseException]],
        b: Optional[BaseException],
        c: Optional[TracebackType],
    ) -> None:
        (
            live,
            maybe_prev_fake_mode,
            maybe_prev_only_lift_cpu_tensors,
        ) = self.enter_stack.pop()
        if live:
            out = super().__exit__(a, b, c)
            # Re-enable the previous fake mode, if there was one.
            if maybe_prev_fake_mode is not None:
                torch._C._set_dispatch_mode(maybe_prev_fake_mode)
            if maybe_prev_only_lift_cpu_tensors is not None:
                torch._C._set_only_lift_cpu_tensors(maybe_prev_only_lift_cpu_tensors)

    @classmethod
    def is_infra_mode(cls) -> bool:
        return True

    @classmethod
    def cache_info(cls) -> DispatchCacheInfo:
        """
        Query the state of the dispatch cache.
        """
        return DispatchCacheInfo(
            FakeTensorMode.cache_hits,
            FakeTensorMode.cache_misses,
            dict(FakeTensorMode.cache_bypasses),
            len(FakeTensorMode.cache),
        )

    @classmethod
    def cache_clear(cls) -> None:
        """
        Clear the dispatch cache.
        """
        cls.cache_hits = 0
        cls.cache_misses = 0
        cls.cache_bypasses.clear()
        cls.cache.clear()

    def _cached_dispatch_impl(
        self,
        func: OpOverload,
        types: Sequence[Type],
        args: Sequence[object],
        kwargs: Mapping[str, object],
    ) -> object:
        """
        Lookup a cache entry for the given arguments. If none exists, dispatch
        and cache the result (if the result is eligible for caching).
        """
        output: object = _UNASSIGNED
        try:
            state = _CacheKeyState(self.shape_env)
            key = self._cache_key(state, func, args, kwargs)
            if state.cache_on_shape_env():
                assert state.shape_env is not None
                cache = state.shape_env.fake_tensor_cache
            else:
                cache = FakeTensorMode.cache
            entry = cache.get(key, None)
            if entry is not None:
                output = self._output_from_cache_entry(state, entry, key, func, args)
                FakeTensorMode.cache_hits += 1
                if self.cache_crosscheck_enabled:
                    # For debugging / testing: Validate that the output synthesized
                    # from the cache matches the output created by normal dispatch.
                    self._crosscheck_cache_output(output, func, types, args, kwargs)
            else:
                self._validate_cache_key(func, args, kwargs)
                output = self._dispatch_impl(func, types, args, kwargs)
                entry = self._make_cache_entry(state, key, func, args, kwargs, output)
                key.strip_shape_env()
                cache[key] = entry
                FakeTensorMode.cache_misses += 1
        except _BypassDispatchCache as e:
            FakeTensorMode.cache_bypasses[e.reason] += 1

        if output is _UNASSIGNED:
            output = self._dispatch_impl(func, types, args, kwargs)

        return output

    def _cache_key(
        self,
        state: _CacheKeyState,
        func: OpOverload,
        args: Sequence[object],
        kwargs: Mapping[str, object],
    ) -> _DispatchCacheKey:
        """
        Create a cache key given the dispatch args. Raises _BypassDispatchCache
        for any situation that precludes caching.
        """
        key_values = [
            func,
            # Capture the default_dtype mode since that can affect the output tensor,
            # e.g., when operating on constant float values.
            torch.get_default_dtype(),
            # Capture the current device to support, e.g., cache tensor creation,
            # where there isn't necessarily a tensor to take the device from.
            torch._C._get_default_device(),
            # We want to create tensors from cached metadata only when the inference
            # mode is the same.
            torch.is_inference_mode_enabled(),
            # Shape env settings could affect behavior. One example seen in the wild:
            # Disallowing dynamic shapes can introduce a DynamicOutputShapeException
            # where it wasn't seen on a previous instance of the same op.
            self.shape_env.settings if self.shape_env else None,
        ]
        # Translate any FakeTensor args to metadata.
        if args:
            self._prep_args_for_hash(key_values, args, state)
        if kwargs:
            self._prep_args_for_hash(key_values, kwargs, state)
        return _DispatchCacheKey(tuple(key_values))

    def _validate_cache_key(
        self,
        func: OpOverload,
        args: Sequence[object],
        kwargs: Mapping[str, object],
    ) -> None:
        """
        Validate that the cache key generated by _cache_key will be
        reasonable.
        """
        # Avoid caching for any ops that would require a more sophisticated
        # caching implementation, e.g., data dependent ops or ops that modify
        # the inputs.
        if torch.Tag.data_dependent_output in func.tags:
            raise _BypassDispatchCache("data dependent output")

        if torch.Tag.dynamic_output_shape in func.tags:
            raise _BypassDispatchCache("dynamic output shape")

        if torch.Tag.inplace_view in func.tags:
            raise _BypassDispatchCache("inplace view")

        if func == aten._unsafe_view.default:
            raise _BypassDispatchCache("unsafe view")

        if func in self.lift_fns:
            raise _BypassDispatchCache("lift")

        if func.name() == "inductor::resize_storage_bytes_":
            raise _BypassDispatchCache("inductor::resize_storage_bytes_")

        if not torch._library.utils.is_builtin(func):
            raise _BypassDispatchCache("non-builtin")

        # In order to handle storage aliasing, we need to establish the alias
        # for any view op on a cache hit. But CompositeImplicitAutograd ops may
        # or may not alias the input, so just punt on caching these.
        if func.is_view and torch._C._dispatch_has_kernel_for_dispatch_key(
            func.name(), torch._C.DispatchKey.CompositeImplicitAutograd
        ):
            raise _BypassDispatchCache("CompositeImplicitAutograd")

    def _prep_args_for_hash(
        self,
        result: List[object],
        args: Union[Mapping[str, object], Sequence[object], Iterable[object]],
        state: _CacheKeyState,
    ) -> None:
        """
        Translate the provided args into a form suitable for caching at FakeTensor
        dispatch, i.e., convert unhashable types like lists & dicts into tuples and
        convert FakeTensors into metadata. Raises _BypassDispatchCache to signal
        unsupported cases that should bypass caching.
        """
        if isinstance(args, dict):
            self._prep_args_for_hash(result, args.keys(), state)
            self._prep_args_for_hash(result, args.values(), state)
            return

        for arg in args:
            if isinstance(arg, FakeTensor):
                if not self.is_our_fake(arg):
                    raise _BypassDispatchCache("not our fake")
                if arg.constant is not None:
                    raise _BypassDispatchCache("constant attribute")
                if is_sparse_any(arg):
                    raise _BypassDispatchCache(f"{arg.layout} tensor")
                # FIXME: For now back out caching when there are symbolic nbytes
                # - this doesn't seem to play nice with set(). See T196779132 for examples.
                if isinstance(arg.untyped_storage().nbytes(), SymInt):
                    raise _BypassDispatchCache("symbolic nbytes")
                metadata = extract_tensor_metadata(arg)
                metadata._flatten_into(result, self, state)
            elif isinstance(arg, Tensor):
                raise _BypassDispatchCache("non-fake tensor")
            elif isinstance(arg, SymInt):
                state.convert_sym_int(result, arg)
            elif isinstance(arg, (SymBool, SymFloat)):
                raise _BypassDispatchCache("symbolic shape")
            elif isinstance(arg, (list, tuple, dict)):
                self._prep_args_for_hash(result, arg, state)
            else:
                # It's important to capture the type of the arg since, e.g., 1 and 1.0
                # hash to the same value, but can produce different dtypes for the
                # output tensor.
                result.append(type(arg))
                result.append(arg)

    def _make_cache_entry(
        self,
        state: _CacheKeyState,
        key: _DispatchCacheKey,
        func: OpOverload,
        args: Sequence[object],
        kwargs: Mapping[str, object],
        output: Optional[FakeTensor],
    ) -> _DispatchCacheEntry:
        """
        Make a cache entry object for the given 'output' Tensor. Raises
        _BypassDispatchCache if the output tensor has characteristics that
        prevent caching it.
        """
        if output is None:
            return _DispatchCacheEntry(inplace_idx=None, metadata=None, view_idx=None)

        # Some ops return tuples of Tensors, but it's rare, so avoid
        # the complexity of caching other types.
        if not isinstance(output, FakeTensor):
            raise _BypassDispatchCache("non-FakeTensor output")

        # Avoid caching FakeTensors with constants attached since those
        # can be invalidated.
        if output.constant is not None:
            raise _BypassDispatchCache("constant attribute")

        # TODO: support caching sparse outputs?
        if output.is_sparse:
            raise _BypassDispatchCache("sparse output")

        if is_sparse_compressed(output):
            raise _BypassDispatchCache("sparse compressed output")

        # Can an in-place op really reference a kwarg? If so, then we need
        # to extend the implementation to handle it.
        for kval in kwargs.values():
            if id(kval) == id(output):
                raise _BypassDispatchCache("kwarg aliases output")

        # If this is an in-place op, the entry records which input arg is aliased.
        for idx in range(len(args)):
            if id(args[idx]) == id(output):
                return _DispatchCacheEntry(
                    inplace_idx=idx, metadata=None, view_idx=None
                )

        # Otherwise, create an entry that records the output tensor's metadata.
        view_idx = None
        if func.is_view:
            idxs = [i for i, t in enumerate(args) if isinstance(t, Tensor)]
            assert len(idxs) == 1
            view_idx = idxs[0]

        metadata = extract_tensor_metadata(output)
        metadata.shape = tuple(state.convert_output(v) for v in metadata.shape)
        metadata.stride = tuple(state.convert_output(v) for v in metadata.stride)
        metadata.storage_offset = state.convert_output(metadata.storage_offset)
        metadata.storage_bytes = (
            None
            if metadata.storage_bytes is None
            else state.convert_output(metadata.storage_bytes)
        )

        entry = _DispatchCacheEntry(
            inplace_idx=None,
            metadata=metadata,
            view_idx=view_idx,
        )

        # N.B.: Some checks for bypassing the cache would be performed on the
        # output tensor synthesized from the cached metadata. As an optimization,
        # we can synthesize a tensor here and do the checks on that instance.
        # This approach keeps the (more frequent) cache-hit path as lightweight
        # as possible.
        synth_output = self._output_from_cache_entry(state, entry, key, func, args)

        # Make sure the dispatch_key_set from the synthesized output tensor will
        # be the same.
        synth_key_set = torch._C._dispatch_key_set(synth_output)
        key_set = torch._C._dispatch_key_set(output)
        if synth_key_set != key_set:
            raise _BypassDispatchCache("dispatch_key_set mismatch")

        return entry

    def _output_from_cache_entry(
        self,
        state: _CacheKeyState,
        entry: _DispatchCacheEntry,
        key: _DispatchCacheKey,
        func: OpOverload,
        args: Sequence[object],
    ) -> Optional[FakeTensor]:
        """
        Create a new FakeTensor from the cache entry.
        """
        if entry.inplace_idx is not None:
            # This is an in-place op; return the aliased arg.
            inplace_arg = args[entry.inplace_idx]
            assert isinstance(inplace_arg, FakeTensor)
            return inplace_arg

        # Synthesize a new FakeTensor with the cached metadata.
        metadata = entry.metadata
        if metadata is None:
            return None

        assert not is_sparse_any(metadata)

        def check_value(
            value: _MetadataIntLike, state: _CacheKeyState
        ) -> Union[IntLikeType]:
            if isinstance(value, _SymIntOutputStub):
                assert state.shape_env is not None
                return value.extract(key, state.shape_env)
            else:
                assert not isinstance(value, _PySymInputStub)
                return value

        shape = tuple(check_value(v, state) for v in metadata.shape)
        stride = tuple(check_value(v, state) for v in metadata.stride)
        storage_offset = check_value(metadata.storage_offset, state)
        storage_bytes = (
            None
            if metadata.storage_bytes is None
            else check_value(metadata.storage_bytes, state)
        )

        maybe_suppress: Callable[[], typing.ContextManager] = contextlib.nullcontext
        if self.shape_env is not None:
            maybe_suppress = self.shape_env.suppress_guards

        with in_kernel_invocation_manager(self), maybe_suppress():
            empty = torch.empty_strided(
                shape,
                stride,
                dtype=metadata.dtype,
                layout=metadata.layout,
                device="meta",
                requires_grad=metadata.requires_grad,
            )

        if metadata.is_conj:
            torch._C._set_conj(empty, True)
        if metadata.is_neg:
            torch._C._set_neg(empty, True)

        if func.is_view:
            # For view ops, the storage should be the same as the tensor input.
            view_arg = args[cast(int, entry.view_idx)]
            assert isinstance(view_arg, FakeTensor)
            storage = view_arg.untyped_storage()
            with in_kernel_invocation_manager(self), maybe_suppress():
                empty.set_(storage, storage_offset, shape, stride)

        return FakeTensor(self, empty, metadata.device)

    def _crosscheck_cache_output(
        self,
        output: Optional[FakeTensor],
        func: OpOverload,
        types: Sequence[Type],
        args: Sequence[object],
        kwargs: Mapping[str, object],
    ) -> None:
        """
        Helper to validate that the output synthesized from the cache matches
        the output created by normal dispatch.
1643 """ 1644 try: 1645 true_output = self._dispatch_impl(func, types, args, kwargs) 1646 except Exception as e: 1647 raise RuntimeError( 1648 f"FakeTensor cache crosscheck failure: func={func}, " 1649 f"args={args}, kwargs={kwargs}: Dispatch raised={e}" 1650 ) from e 1651 try: 1652 if (true_output is not None) and (output is not None): 1653 assert_metadata_eq(assert_eq, true_output, output) 1654 else: 1655 assert true_output is None 1656 assert output is None 1657 except Exception as e: 1658 raise RuntimeError( 1659 f"FakeTensor cache crosscheck failure: func={func}, " 1660 f"args={args}, kwargs={kwargs}" 1661 ) from e 1662 1663 def dispatch( 1664 self, 1665 func: OpOverload, 1666 types: Sequence[Type], 1667 args: Sequence[object] = (), 1668 kwargs: Mapping[str, object] = immutable_dict(), 1669 ) -> object: 1670 kwargs = kwargs or {} 1671 with no_dispatch(): 1672 log.debug("%s %s %s", func, args, kwargs) 1673 1674 if func in _DISPATCH_META_HANDLERS: 1675 return _DISPATCH_META_HANDLERS[func](args) 1676 1677 if log.getEffectiveLevel() <= logging.DEBUG: 1678 log.debug( 1679 "%sFakeTensorMode.__torch_dispatch__: %s", " " * RECURSION_COUNT, func 1680 ) 1681 # NOTE: incr is intentionally unused for a RAII pattern 1682 incr = IncrementRecursionCount() 1683 1684 # Some attribute queries that can be serviced directly 1685 # See Note [is_coalesced is dispatched] 1686 if func in _DISPATCH_HANDLE_DIRECTLY: 1687 # NB: no_dispatch is ok here too, this func is very simple 1688 with in_kernel_invocation_manager(self): 1689 return func(*args, **kwargs) 1690 1691 if self.cache_enabled: 1692 return self._cached_dispatch_impl(func, types, args, kwargs) 1693 else: 1694 return self._dispatch_impl(func, types, args, kwargs) 1695 1696 def _dispatch_impl( 1697 self, 1698 func: OpOverload, 1699 types: Sequence[Type], 1700 args: Sequence[object], 1701 kwargs: Mapping[str, object], 1702 ) -> Optional[FakeTensor]: 1703 flat_args, args_spec = pytree.tree_flatten((args, kwargs)) 1704 1705 # DO NOT PUT LOGIC BEFORE UNRECOGNIZED TYPE CHECKING 1706 # We must throw NotImplemented in case of unrecognized types to handle subclasses. 1707 # Throwing the exception will pass the control to the next __torch_dispatch__. 1708 # See [subclass inputs] below 1709 # NB: If you're seeing a mysterious infinite loop involving fake 1710 # tensor, it might be related to this line. Though I'm not sure 1711 # how you'll know to read this comment, as this line won't show up 1712 # in the stack trace. 1713 has_unrecognized_types = _check_for_subclass(flat_args) 1714 if has_unrecognized_types: 1715 unrecognized_types = [ 1716 type(x) for x in flat_args if _check_for_subclass_arg(x) 1717 ] 1718 not_implemented_log.debug( 1719 "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types 1720 ) 1721 return NotImplemented 1722 1723 flat_arg_fake_tensors = [t for t in flat_args if self.is_our_fake(t)] 1724 has_symbolic_sizes = any( 1725 i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors 1726 ) or any(isinstance(a, SymInt) for a in flat_args) 1727 1728 converter = self.fake_tensor_converter 1729 1730 is_lift_func = func in self.lift_fns 1731 1732 # To constant propagate through these functions: 1733 # 1, If this is a lift due to a torch.tensor call, 1734 # the input tensor is guaranteed to be a 1735 # constant, so we keep a copy of the original argument along so 1736 # we can query it if we're asked to item() it at some later point. 
1737 # (Note that you can always call a lift fn manually, so we do 1738 # have to check if there are any fake tensors!) 1739 # 2, Some functions that allow Python numbers to bind to Tensors, e.g., torch.div 1740 if (is_lift_func and not flat_arg_fake_tensors) or ( 1741 should_allow_numbers_as_tensors(func) 1742 and not has_symbolic_sizes 1743 and not flat_arg_fake_tensors 1744 ): 1745 assert all( 1746 t.constant is not None for t in flat_arg_fake_tensors 1747 ), f"{func} should not have fake inputs without constants" 1748 const_flat_args = [ 1749 a.constant if self.is_our_fake(a) else a for a in flat_args 1750 ] 1751 const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec) 1752 out = func(*const_args, **const_kwargs) 1753 if type(out) is Tensor and self.may_turn_const(out): 1754 # NB: not in_kernel_invocation_manager because we're doing real 1755 # compute here 1756 # NB: no_dispatch() here is VERY DANGEROUS (like, segfault 1757 # dangerous) if this is actually a wrapper subclass tensor, 1758 # therefore the exact type test above 1759 with no_dispatch(): 1760 out = out.clone() 1761 return converter.from_real_tensor(self, out, make_constant=True) 1762 1763 # if we are in the dispatch mode, we will enter this function even if the inputs 1764 # are not FakeTensors. For now, throw if there are any non-Fake Tensor inputs 1765 # and just support constructors. 1766 1767 # this is generated from torch.tensor(), which does not use the 1768 # dispatcher, to allow wrapper subclasses to wrap the new tensor 1769 if is_lift_func: 1770 assert len(kwargs) == 0 and len(args) == 1, f"{args} {kwargs}" 1771 1772 if type(args[0]) is Tensor: 1773 return converter.from_real_tensor(self, args[0]) 1774 1775 # If we are trying to avoid device init, then we need to avoid constant 1776 # prop on constant tensors for ops that change devices. 1777 avoiding_device_init = False 1778 if self.avoid_device_init: 1779 if ( 1780 func == torch.ops.aten._to_copy.default 1781 and "device" in kwargs 1782 and kwargs["device"] != "cpu" 1783 ): 1784 avoiding_device_init = True 1785 if func == torch.ops.prims.device_put.default: 1786 avoiding_device_init = True 1787 1788 # Recompute flat_arg_fake_tensors here again in case some of the inputs 1789 # were real tensors and fakified in validate_and_convert_non_fake_tensors 1790 (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors( 1791 func, converter, flat_args, args_spec 1792 ) 1793 del args, kwargs # Invalidated 1794 1795 # The current constant handling only supports tracing systems 1796 # (aot autograd, torchdynamo) where each operation is run consecutively. 1797 # Because each operation is run in order, we can trace out and support 1798 # sequences like: x = torch.tensor(0.); y = x.add_(1) 1799 # Whenever a constant is written to but with inputs that cannot be evaluated 1800 # statically, such as random_(), we invalidate all constants that alias the input. 1801 # We will rely on functionalization for use of fake tensor constants as persistent 1802 # objects on an FX Graph.
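# Illustrative sketch of the constant-handling contract described above
# (comments only, nothing here is executed; assumes an active FakeTensorMode
# under a tracing system such as aot_autograd or torchdynamo):
#
#     x = torch.tensor(1.0)    # lift -> FakeTensor with x.constant set
#     y = x.add_(1)            # statically evaluable, so the constant is propagated
#     x.random_()              # not statically evaluable: every constant that
#                              # aliases x is invalidated
#     x.item()                 # must no longer report the stale constant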
1803 1804 # We dispatch size/stride/numel on the FakeTensor not its constant, so bail on inplace_view 1805 all_constant = all(e.constant is not None for e in flat_arg_fake_tensors) 1806 if ( 1807 torch.Tag.nondeterministic_seeded not in func.tags 1808 and torch.Tag.inplace_view not in func.tags 1809 and all_constant 1810 and len(flat_arg_fake_tensors) != 0 1811 and not has_symbolic_sizes 1812 and not avoiding_device_init 1813 ): 1814 const_flat_args = [ 1815 a.constant if self.is_our_fake(a) else a for a in flat_args 1816 ] 1817 const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec) 1818 1819 # NB: not in_kernel_invocation_manager(self) as we want to do REAL 1820 # compute 1821 with no_dispatch(): 1822 out = func(*const_args, **const_kwargs) 1823 1824 flat_out = pytree.tree_leaves(out) 1825 flat_out_tensors = [t for t in flat_out if isinstance(t, Tensor)] 1826 all_constant = all(self.may_turn_const(t) for t in flat_out_tensors) 1827 1828 if all_constant: 1829 return pytree.tree_map_only( 1830 Tensor, 1831 lambda t: converter.from_real_tensor(self, t, make_constant=True), 1832 out, 1833 ) 1834 1835 # we weren't able to turn outputs to constants, 1836 # so invalidate all constants that might be aliases of the outputs 1837 for ten in flat_out_tensors: 1838 converter.invalidate_constant_aliases(ten) 1839 1840 # we are falling through to running non constant tensors, any input constant that 1841 # is written to must be invalidated 1842 args, kwargs = pytree.tree_unflatten(flat_args, args_spec) 1843 self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) 1844 1845 def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]: 1846 if isinstance(t, FakeTensor): 1847 return t.real_tensor 1848 elif isinstance(t, py_sym_types): 1849 assert self.shape_env is not None 1850 return t.node.pytype( 1851 t.node.expr.xreplace(self.shape_env.var_to_val).xreplace( 1852 self.shape_env.unbacked_var_to_val 1853 ) 1854 ) 1855 else: 1856 return t 1857 1858 from torch.fx.experimental.symbolic_shapes import ( 1859 compute_unbacked_bindings, 1860 free_unbacked_symbols, 1861 ) 1862 1863 nil = object() 1864 1865 real_out = nil 1866 if ( 1867 self.propagate_real_tensors 1868 and all(e.real_tensor is not None for e in flat_arg_fake_tensors) 1869 # TODO: Handle SymFloat/SymBool 1870 and not any( 1871 ( 1872 isinstance(a, SymInt) 1873 and (syms := free_unbacked_symbols(a)) 1874 and self.shape_env is not None 1875 and any(s not in self.shape_env.unbacked_var_to_val for s in syms) 1876 ) 1877 for a in flat_args 1878 ) 1879 ): 1880 real_flat_args = [maybe_to_real_tensor(a) for a in flat_args] 1881 real_args, real_kwargs = pytree.tree_unflatten(real_flat_args, args_spec) 1882 real_out = func(*real_args, **real_kwargs) 1883 elif self.propagate_real_tensors: 1884 # This can happen occasionally legitimately, specifically when you 1885 # are inside the meta of a data dependent operation and you create 1886 # a tensor on an unbacked SymInt; at this point in time we don't 1887 # know what the unbacked SymInt is, but we will know later. 1888 # However, if there's a bug in the condition above, this condition 1889 # will also trigger. 
1890 log.debug( 1891 "propagate_real_tensors skipped %s(%s, %s) %s", 1892 func, 1893 flat_arg_fake_tensors, 1894 flat_args, 1895 self.shape_env.unbacked_var_to_val if self.shape_env else None, 1896 ) 1897 1898 def maybe_propagate_real_tensors(fake_out: T) -> T: 1899 import sympy 1900 1901 def go(t: object, real_t: Tensor) -> None: 1902 if isinstance(t, FakeTensor): 1903 # NB: unconditionally overwrite 1904 t.real_tensor = real_t 1905 elif isinstance(t, py_sym_types) and free_unbacked_symbols(t): 1906 if isinstance(t.node.expr, sympy.Symbol): 1907 assert self.shape_env is not None 1908 self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t) 1909 1910 if real_out is not nil: 1911 tree_map_(go, fake_out, real_out) 1912 1913 # If a data-dependent op is used in a decomposition, we 1914 # may need to get the unbacked settings "early" 1915 # TODO: Is this really needed? 1916 compute_unbacked_bindings(self.shape_env, fake_out, peek=True) 1917 1918 return fake_out 1919 1920 # Try for fastpath 1921 if has_symbolic_sizes: 1922 fast_impl = get_fast_op_impls().get(func) 1923 if fast_impl is not None: 1924 return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs)) 1925 1926 # If there's a Python meta, prefer that over the decomposition 1927 from torch._decomp import meta_table as meta_table 1928 1929 if func not in meta_table and not self.cpp_meta_supports_symint(func): 1930 from torch._decomp import decomposition_table 1931 1932 # Prefer Python decompositions over C++ ones 1933 if func in decomposition_table and ( 1934 has_symbolic_sizes 1935 or ( 1936 # TODO: Remove these exclusions, so that we can remove 1937 # this leg entirely 1938 torch_decomp_decompositions(func) 1939 and all(not is_sparse_any(e) for e in flat_arg_fake_tensors) 1940 ) 1941 ): 1942 with self: 1943 return decomposition_table[func](*args, **kwargs) 1944 1945 with self: 1946 # Decomposes CompositeImplicitAutograd ops 1947 r = func.decompose(*args, **kwargs) 1948 if r is not NotImplemented: 1949 return r 1950 1951 # prims already wrap FakeTensor inputs to FakeTensor outputs 1952 # and do device logic, so we don't need to do anything but run them 1953 # and ensure that Meta kernels are dispatched to (see 1954 # Fake Tensor Dispatch Keys) 1955 # TODO - we should use the prim aten impl 1956 # TODO - fix prims complex ops 1957 if ( 1958 "prims::" in func._schema.name 1959 and hasattr(func, "prim_meta_impl") 1960 and not stride_incorrect_op(func) 1961 ): 1962 with self: 1963 return maybe_propagate_real_tensors( 1964 func.prim_meta_impl(*args, **kwargs) 1965 ) 1966 1967 # Users can register FakeTensor rules for custom operators. 1968 # Call them if they exist.
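# A user-side registration that ends up in the registry queried below typically
# looks like the following (illustrative sketch only; "mylib::bar" is a
# hypothetical custom op, and torch.library.register_fake is the public API for
# providing a FakeTensor rule):
#
#     @torch.library.register_fake("mylib::bar")
#     def _(x):
#         return x.new_empty(x.shape)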
1969 maybe_fake_impl = torch._library.simple_registry.singleton.find( 1970 func.name() 1971 ).fake_impl.kernel 1972 if maybe_fake_impl: 1973 ctx = torch._library.fake_impl.FakeImplCtx(self, func) 1974 with torch._library.fake_impl.set_ctx_getter(lambda: ctx), self: 1975 result = maybe_fake_impl(*args, **kwargs) 1976 return maybe_propagate_real_tensors(result) 1977 1978 # special handling for funcs registered through `register_op_impl`, 1979 # e.g., manipulating args on constructor calls to construct meta tensors 1980 # and then afterwards wrapping them in a FakeTensor 1981 for run_impl_check, op_impl in op_implementations_checks: 1982 if run_impl_check(func): 1983 op_impl_out = op_impl(self, func, *args, **kwargs) 1984 if op_impl_out is not NotImplemented: 1985 return maybe_propagate_real_tensors(op_impl_out) 1986 1987 def maybe_run_unsafe_fallback( 1988 error: Optional[RuntimeError] = None, 1989 ) -> Optional[FakeTensor]: 1990 # We infer the meta of custom ops that return None to just 1991 # return None. Custom ops are not allowed to mutate the metadata 1992 # of their inputs, so this is safe. 1993 if torch._library.utils.can_generate_trivial_fake_impl(func): 1994 return None 1995 # no meta kernel registered, fall back to the kernel for the device 1996 if has_symbolic_sizes or not self.can_run_unsafe_fallback(func): 1997 raise UnsupportedOperatorException(func) 1998 if error is None: 1999 error = UnsupportedOperatorException(func) 2000 return run_fallback_kernel(self, func, flat_args, args_spec, error) 2001 2002 # Optimization: If there is no Meta kernel, it takes a surprisingly long 2003 # amount of time to catch the NotImplementedError, so we check it here. 2004 if not has_meta(func): 2005 fallback = maybe_run_unsafe_fallback() 2006 return maybe_propagate_real_tensors(fallback) 2007 2008 # run the kernel registered to meta for func, which includes 2009 # python meta registrations, prims, decomps, and c++ meta fns (structured kernels). 2010 # It's possible that the kernel will raise NotImplementedError 2011 try: 2012 with in_kernel_invocation_manager(self): 2013 r = func(*args, **kwargs) 2014 except NotImplementedError as not_implemented_error: 2015 return maybe_run_unsafe_fallback(not_implemented_error) 2016 except Exception: 2017 log.exception("failed while attempting to run meta for %s", func) 2018 raise 2019 2020 return maybe_propagate_real_tensors( 2021 self.wrap_meta_outputs_with_default_device_logic( 2022 r, func, flat_args, device=kwargs.get("device") 2023 ) 2024 ) 2025 2026 # WARNING: DO NOT add any additional namespaces/operators here if they refer to operators 2027 # outside of the pytorch/pytorch library! Any pre-existing things here 2028 # are either in the pytorch/pytorch library or have been grandfathered in. 2029 # The fallback does not always work and MAY CRASH and emit unreadable error messages, 2030 # so it should not be allowed by default. 2031 _can_run_unsafe_fallback_allowed_namespaces = ordered_set( 2032 "debugprims", 2033 "prims", 2034 "aten", 2035 "xla", 2036 "vision", 2037 "torchtext", 2038 "torchaudio", 2039 "quantized", 2040 ) 2041 2042 def can_run_unsafe_fallback(self, func: OpOverload) -> bool: 2043 if not self.allow_fallback_kernels: 2044 return False 2045 # It's OK to try the fallback for built-in ops (e.g.
aten, prims) 2046 # because we control and test these, but the fallback leads to unexpected behavior 2047 # in user-defined custom ops 2048 return ( 2049 func.namespace in self._can_run_unsafe_fallback_allowed_namespaces 2050 or func.name() == "fbgemm::gmm" 2051 ) 2052 2053 def validate_and_convert_non_fake_tensors( 2054 self, 2055 func: OpOverload, 2056 converter: FakeTensorConverter, 2057 flat_args: Sequence[object], 2058 args_spec: TreeSpec, 2059 ) -> Tuple[List[object], List[FakeTensor]]: 2060 """ 2061 Checks that the tensors in the flattened args are fake tensors. 2062 If not, try to convert them to fake tensors. 2063 Returns the flattened args (with any real tensors converted) and the list of fake tensors among them. 2064 """ 2065 flat_arg_fake_tensors: List[FakeTensor] = [] 2066 2067 def validate(x: T) -> Union[T, FakeTensor]: 2068 if not isinstance(x, Tensor): 2069 return x 2070 2071 nonlocal flat_arg_fake_tensors 2072 if not self.is_our_fake(x): 2073 if torch.Tag.inplace_view in func.tags: 2074 args, kwargs = pytree.tree_unflatten(flat_args, args_spec) 2075 raise AssertionError( 2076 f"Can't call metadata mutating ops on non-Fake Tensor inputs. Found in {render_call(func, args, kwargs)}" 2077 ) 2078 if not self.allow_non_fake_inputs: 2079 if isinstance(x, FakeTensor) and x.fake_mode is not self: 2080 raise AssertionError("Mixing fake modes NYI") 2081 args, kwargs = pytree.tree_unflatten(flat_args, args_spec) 2082 raise AssertionError( 2083 f"Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode " 2084 f"with 'allow_non_fake_inputs'. Found in {render_call(func, args, kwargs)}" 2085 ) 2086 2087 out = converter.from_real_tensor(self, x) 2088 else: 2089 out = x 2090 2091 flat_arg_fake_tensors.append(out) 2092 return out 2093 2094 validated_args = [validate(a) for a in flat_args] 2095 return validated_args, flat_arg_fake_tensors 2096 2097 def wrap_meta_outputs_with_default_device_logic( 2098 self, 2099 r: object, 2100 func: OpOverload, 2101 flat_args: Sequence[object], 2102 device: Optional[torch.device], 2103 ) -> PyTree: 2104 converter = self.fake_tensor_converter 2105 2106 # Lazily initialized, in case there are no tensor returns 2107 common_device = None 2108 has_scalar_only_inputs = False 2109 2110 def wrap(e: T) -> Union[T, FakeTensor]: 2111 nonlocal common_device 2112 nonlocal has_scalar_only_inputs 2113 2114 if not isinstance(e, Tensor): 2115 return e 2116 2117 if common_device is None: 2118 ( 2119 common_device, 2120 has_scalar_only_inputs, 2121 ) = FakeTensor._find_common_device(func, flat_args) 2122 2123 is_our_fake = self.is_our_fake(e) 2124 if is_our_fake: 2125 torch._check( 2126 e.device == common_device, 2127 lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}", 2128 ) 2129 return cast(T, e) 2130 elif converter is not None: 2131 if has_scalar_only_inputs: 2132 # Under FakeTensorMode, an op that accepts scalar-only inputs, such as aten.add/sub/mul/div, 2133 # returns a real scalar tensor on CPU. See TensorMeta() in _prims/__init__.py for details. 2134 # We thus directly convert the real tensor to a fake tensor.
2135 return converter.from_real_tensor(self, e) 2136 else: 2137 return converter.from_meta_and_device( 2138 self, e, device or common_device 2139 ) 2140 else: 2141 return e 2142 2143 return tree_map(wrap, r) 2144 2145 def create_symbolic_nested_int( 2146 self, *, nt_tensor_id: Optional[int] = None 2147 ) -> torch.SymInt: 2148 # See Note: [Creating symbolic nested int] 2149 # Returned nested int always has coeff=1; multiply the result by coeff if needed 2150 import torch.nested._internal.nested_tensor 2151 2152 if nt_tensor_id is None: 2153 nt_tensor_id = self.nt_tensor_id_counter 2154 assert self.enter_stack, "should only called while FakeTensorMode is active" 2155 self.nt_tensor_id_counter += 1 2156 hint = torch._C._get_nested_int(nt_tensor_id, 1) 2157 2158 src = torch._dynamo.source.EphemeralSource("intermediate_offsets_or_lengths") 2159 assert self.shape_env is not None 2160 ret = self.shape_env.create_symintnode( 2161 sym=self.shape_env.create_symbol( 2162 val=hint, 2163 source=src, 2164 ), 2165 hint=hint, 2166 source=src, 2167 ) 2168 return ret 2169 2170 _cpp_meta_supports_symint = ordered_set( 2171 aten.empty.memory_format, 2172 aten.empty_strided.default, 2173 aten.as_strided_scatter.default, 2174 aten.as_strided.default, 2175 aten.as_strided_.default, 2176 aten.zeros.default, 2177 aten.detach.default, 2178 aten.view_as_real.default, 2179 aten.view_as_complex.default, 2180 aten.set_.source_Storage_storage_offset, 2181 aten._sparse_coo_tensor_with_dims_and_tensors.default, 2182 ) 2183 2184 def cpp_meta_supports_symint(self, func: OpOverload) -> bool: 2185 if torch.Tag.view_copy in func.tags: 2186 return True 2187 return func in self._cpp_meta_supports_symint 2188 2189 lift_fns = ordered_set(aten.lift_fresh.default, aten.lift_fresh_copy.default) 2190 2191 def may_turn_const(self, t: Tensor) -> bool: 2192 return ( 2193 t.numel() <= CONSTANT_NUMEL_LIMIT 2194 and not is_sparse_any(t) 2195 and not self.is_our_fake(t) 2196 and not t.device.type == "meta" 2197 ) 2198 2199 def invalidate_written_to_constants( 2200 self, 2201 func: OpOverload, 2202 flat_arg_fake_tensors: Sequence[FakeTensor], 2203 args: Sequence[object], 2204 kwargs: Mapping[str, object], 2205 ) -> None: 2206 any_constant = any(e.constant is not None for e in flat_arg_fake_tensors) 2207 schema_info = get_schema_info(func) 2208 if any_constant and schema_info.is_mutable(): 2209 _, new_kwargs = normalize_function( # type: ignore[misc] 2210 func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] 2211 ) 2212 for k, v in new_kwargs.items(): 2213 k = k if (k != "input" or schema_info.has_argument(k)) else "self" 2214 if ( 2215 self.is_our_fake(v) 2216 and schema_info.is_mutable(k) 2217 and v.constant is not None 2218 ): 2219 self.fake_tensor_converter.invalidate_constant_aliases(v.constant) 2220 2221 def from_tensor( 2222 self, 2223 tensor: Tensor, 2224 *, 2225 static_shapes: Optional[bool] = None, 2226 source: Optional[Source] = None, 2227 symbolic_context: Optional[SymbolicContext] = None, 2228 trace: bool = True, 2229 ) -> FakeTensor: 2230 shape_env: Optional[ShapeEnv] = self.shape_env 2231 if static_shapes is None: 2232 static_shapes = self.static_shapes 2233 if static_shapes: 2234 assert ( 2235 symbolic_context is None 2236 ), "cannot set both static_shapes and symbolic_context" 2237 shape_env = None 2238 return self.fake_tensor_converter.from_real_tensor( 2239 self, 2240 tensor, 2241 shape_env=shape_env, 2242 source=source, 2243 symbolic_context=symbolic_context, 2244 trace=trace, 2245 ) 
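# FakeTensorMode.from_tensor() usage, for orientation (illustrative sketch only,
# not executed here):
#
#     mode = FakeTensorMode()
#     fake = mode.from_tensor(torch.randn(4, 4))
#     with mode:
#         out = fake @ fake        # meta-only compute; no real kernel runs
#     assert isinstance(out, FakeTensor) and out.shape == (4, 4)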
2246 2247 2248_StoragePointer = object 2249 2250 2251# NB: returns fake tensors 2252def run_fallback_kernel( 2253 fake_mode: FakeTensorMode, 2254 func: OpOverload, 2255 flat_args: Sequence[object], 2256 args_spec: PyTree, 2257 orig_not_implemented_exception: RuntimeError, 2258) -> FakeTensor: 2259 # these should all be supported, but just to be safe, 2260 # avoid the fallback for operators which modify metadata in place, 2261 # because the input fake tensors would be left unmodified 2262 if torch.Tag.inplace_view in func.tags: 2263 raise orig_not_implemented_exception 2264 2265 inp_impls = {} 2266 2267 # Don't use in_kernel_invocation_manager(fake_mode) as we want to do 2268 # REAL compute (not with meta device) 2269 with no_dispatch(): 2270 2271 def to_real_tensor(e: T) -> Union[T, Tensor]: 2272 if fake_mode.is_our_fake(e): 2273 out = torch.zeros_like(e, device=e.fake_device) 2274 if e.is_sparse: 2275 out._coalesced_(e.is_coalesced()) 2276 inp_impls[id(out)] = e 2277 return out 2278 return e 2279 2280 flat_args = [to_real_tensor(a) for a in flat_args] 2281 args, kwargs = pytree.tree_unflatten(flat_args, args_spec) 2282 2283 r = func(*args, **kwargs) 2284 2285 storages: Set[_StoragePointer] = set() 2286 2287 for e in flat_args: 2288 if isinstance(e, Tensor): 2289 if not is_sparse_any(e): 2290 storages.add(e._typed_storage()._cdata) 2291 2292 # TODO: also check metadata change on inputs 2293 # proper aliasing/metadata relationship between outputs and inputs will 2294 # not be set up, because of the conversion to device, unless we can reuse an 2295 # input impl 2296 2297 def map_out(e: T) -> Union[T, FakeTensor]: 2298 if id(e) not in inp_impls and ( 2299 isinstance(e, Tensor) 2300 and not is_sparse_any(e) 2301 and e._typed_storage()._cdata in storages 2302 ): 2303 raise orig_not_implemented_exception 2304 2305 if isinstance(e, Tensor): 2306 if id(e) in inp_impls: 2307 return inp_impls[id(e)] 2308 else: 2309 return fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, e) 2310 else: 2311 return e 2312 2313 return pytree.tree_map(map_out, r) 2314 2315 2316# Only used to allow copying a module to fake tensors; 2317# does not apply elsewhere 2318class FakeCopyMode(TorchFunctionMode): 2319 def __init__(self, fake_mode: FakeTensorMode) -> None: 2320 self.fake_mode = fake_mode 2321 2322 def __torch_function__( 2323 self, 2324 func: OpOverload, 2325 types: Sequence[Type], 2326 args: Sequence[object] = (), 2327 kwargs: Optional[Mapping[str, object]] = None, 2328 ) -> FakeTensor: 2329 kwargs = kwargs if kwargs else {} 2330 2331 # clone will get called in Parameter deepcopy 2332 if func == torch._C.TensorBase.clone: 2333 assert isinstance(args[0], Tensor) 2334 return func( 2335 self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs 2336 ) 2337 elif func == Tensor.__deepcopy__: 2338 assert len(args) == 2 and len(kwargs) == 0 2339 tensor = cast(Tensor, args[0]) 2340 memo = cast(Dict[int, FakeTensor], args[1]) 2341 2342 if id(tensor) in memo: 2343 return memo[id(tensor)] 2344 2345 out = self.fake_mode.from_tensor(tensor, static_shapes=True) 2346 memo[id(tensor)] = out 2347 return out 2348 else: 2349 with torch._C.DisableTorchFunctionSubclass(): 2350 return func(*args, **kwargs) 2351 2352 2353def _device_handler(args: Sequence[object]) -> torch.device: 2354 # NB: Don't use is_our_fake, just serve the fake information 2355 # as is. Notice we don't use 'self'; we use args[0].fake_mode 2356 # because they may not be the same.
It would also be possible 2357 # to return NotImplemented here, in which case the FakeTensor 2358 # handler on args[0] would handle it, but we're being nice and 2359 # short-circuiting quickly. 2360 assert len(args) == 1 and isinstance(args[0], FakeTensor) 2361 if args[0].fake_mode.in_kernel_invocation: 2362 return torch.device("meta") 2363 else: 2364 return args[0].fake_device 2365 2366 2367# [subclass inputs] 2368# Suppose we enable fake tensor mode. This means that fake tensor 2369# mode will run first. But what if we do an operation that 2370# involves a tensor subclass that will desugar into normal tensor 2371# operations? Without returning NotImplemented, fake tensor mode will run first, 2372# decide that a conversion was made (since there was a non-fake 2373# tensor argument), and report an error that converting a non-fake 2374# tensor is not supported. What we actually wanted to happen 2375# was to give the subclass a chance to figure out what it wants to do 2376# before erroring out. Returning NotImplemented here allows this. 2377def _check_for_subclass(flat_args: Sequence[object]) -> bool: 2378 return any(_check_for_subclass_arg(x) for x in flat_args) 2379 2380 2381def _check_for_subclass_arg(x: object) -> bool: 2382 return ( 2383 not isinstance(x, FakeTensor) 2384 and isinstance(x, Tensor) 2385 and type(x) is not Tensor 2386 and type(x) is not torch.nn.Parameter 2387 ) 2388 2389 2390_DISPATCH_META_HANDLERS = { 2391 torch.ops.prim.device.default: _device_handler, 2392 torch.ops.aten.size.default: lambda args: tuple( 2393 int(s) for s in cast(Tensor, args[0]).size() 2394 ), 2395 torch.ops.aten.stride.default: lambda args: tuple( 2396 int(s) for s in cast(Tensor, args[0]).stride() 2397 ), 2398 torch.ops.aten.storage_offset.default: lambda args: int( 2399 cast(Tensor, args[0]).storage_offset() 2400 ), 2401} 2402 2403_DISPATCH_HANDLE_DIRECTLY = ordered_set( 2404 torch.ops.aten.is_coalesced.default, 2405 torch.ops.aten.dense_dim.default, 2406 torch.ops.aten.sparse_dim.default, 2407) 2408 2409from torch._subclasses.fake_impls import ( # noqa: F401 2410 _device_not_kwarg_ops, 2411 _is_tensor_constructor, 2412 _like_tensor_constructors, 2413 contains_tensor_types, 2414 get_fast_op_impls, 2415 has_meta, 2416 op_implementations_checks, 2417 stride_incorrect_op, 2418) 2419 2420 2421@atexit.register 2422def dump_cache_stats() -> None: 2423 log.info("FakeTensor cache stats:") 2424 log.info(" cache_hits: %s", FakeTensorMode.cache_hits) 2425 log.info(" cache_misses: %s", FakeTensorMode.cache_misses) 2426 bypasses = FakeTensorMode.cache_bypasses 2427 if bypasses: 2428 log.info(" cache_bypasses:") 2429 width = max(len(k) for k in bypasses) 2430 for k, v in sorted(bypasses.items(), key=lambda i: -i[1]): 2431 log.info(" %-*s %s", width + 1, f"{k}:", v) 2432
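# The counters logged above can also be inspected directly at runtime
# (illustrative sketch; the atexit output itself only shows up if logging is
# configured to emit INFO records for this logger):
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#     print(FakeTensorMode.cache_hits, FakeTensorMode.cache_misses)
#     for reason, count in sorted(FakeTensorMode.cache_bypasses.items()):
#         print(f"  bypassed ({reason}): {count}")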