# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
import dataclasses
import warnings
import weakref
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
    ClassVar,
    ContextManager,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import TypeAlias

import torch
from torch._C._autograd import CreationMeta
from torch._C._functorch import (
    _add_batch_dim,
    _unwrap_functional_tensor,
    _wrap_functional_tensor,
    get_unwrapped,
    is_batchedtensor,
    is_functorch_wrapped_tensor,
    is_gradtrackingtensor,
    is_legacy_batchedtensor,
    maybe_get_bdim,
    maybe_get_level,
    peek_interpreter_stack,
)
from torch._logging import trace_structured
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import WeakIdKeyDictionary


if TYPE_CHECKING:
    from torch._C._functorch import CInterpreter
    from torch._guards import Source

    # Import here to avoid cycle
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Import the following modules during type checking to enable code intelligence features.
    # Do not import unconditionally, as they import sympy and importing sympy is very slow
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext

DimList = List


def safe_is_leaf(t):
    try:
        return t.is_leaf
    except RuntimeError:
        # inference mode can trigger this
        return False


def safe_grad(t):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
        return t.grad


def assert_eq(a, b):
    assert a == b, f"{a} != {b}"


def assert_metadata_eq(
    assert_eq,
    m1: Union[MetaTensorDesc, torch.Tensor],
    m2: torch.Tensor,
    *,
    skip_symbolic=False,
    skip_leaf=False,
):
    if isinstance(m1, torch.Tensor):
        m1 = MetaTensorDescriber().describe_tensor(m1)

    def go(m1, m2):
        assert_eq(m1.dtype, m2.dtype)
        if not skip_symbolic:
            assert_eq(m1.shape, m2.shape)
        assert_eq(m1.requires_grad, m2.requires_grad)
        if not skip_leaf:
            assert_eq(m1.is_leaf, m2.is_leaf)
        # MetaTensorDesc doesn't store grad_fn; inferred from leaf
        # assert_eq(m1.grad_fn is None, m2.grad_fn is None)
        assert_eq(m1.is_sparse, m2.is_sparse)
        assert_eq(m1.is_inference, m2.is_inference())
        assert_eq(m1.is_conj, m2.is_conj())
        assert_eq(m1.is_neg, m2.is_neg())
        assert_eq(m1.grad is not None, safe_grad(m2) is not None)
        if m1.grad is not None:
            go(m1.grad, safe_grad(m2))
        # TODO: move "assert_eq(m1.layout, m2.layout)" out of sparse
        # branches (but not ready for prime time yet)...
        if m1.is_sparse:
            assert_eq(m1.layout, m2.layout)
            assert_eq(m1.dense_dim, m2.dense_dim())
            assert_eq(m1.sparse_dim, m2.sparse_dim())
            assert_eq(m1.is_coalesced, m2.is_coalesced())
        elif is_sparse_compressed(m1):
            assert_eq(m1.layout, m2.layout)
            assert_eq(m1.dense_dim, m2.dense_dim())
            assert_eq(m1.sparse_dim, m2.sparse_dim())
        else:
            if not skip_symbolic:
                assert_eq(m1.stride, m2.stride())
                assert_eq(m1.storage_offset, m2.storage_offset())
            assert_eq(m1.is_view, m2._is_view())
            if m1.is_view:
                go(m1.base, m2._base)
        # TODO: test if is resizable (no direct query for this atm)
        # TODO: audit AutogradMeta to see if it matches
        # TODO: test forward AD

    return go(m1, m2)


def is_sparse_coo(t):
    return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo


def is_sparse_compressed_layout(layout):
    return layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }


def is_sparse_compressed(t):
    return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout)


def is_sparse_any(t):
    return is_sparse_coo(t) or is_sparse_compressed(t)


# Don't use id() directly, because those can get reallocated over time.
MetaStorageId: TypeAlias = int
MetaTensorId: TypeAlias = int


DESCRIBER_NEXT_ID = 0


class MetaTensorDescriber:
    """
    Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc
    for it, which is enough information to reconstruct a meta tensor/fake tensor
    corresponding to a Tensor as faithfully as possible.

    This is a stateful conversion object because we keep track of the IDs
    of the tensors/storages passed to us, so we can consistently give
    the same ID when we see the same tensor/storage.
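
    Illustrative usage sketch (an informal example, not a stable API contract):

        describer = MetaTensorDescriber()
        t = torch.randn(2, 3)
        d1 = describer.describe_tensor(t)
        d2 = describer.describe_tensor(t)
        assert d1.id == d2.id  # same tensor, same id within this describer
        assert d1.ndim == 2 and d1.dtype == torch.float32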
    """

    def __init__(self, *, copy_data=False):
        global DESCRIBER_NEXT_ID
        self.id = DESCRIBER_NEXT_ID
        DESCRIBER_NEXT_ID += 1
        self.next_tensor_id: MetaTensorId = 0
        self.next_storage_id: MetaStorageId = 0
        # Tensor -> int
        self.lookup_tensor = WeakIdKeyDictionary()
        # Storage -> int
        self.lookup_storage = WeakIdKeyDictionary()
        self.copy_data = copy_data
        self.traced_tensors = set()
        self.traced_storages = set()

    def get_tensor_id(self, t: torch.Tensor):
        if t not in self.lookup_tensor:
            self.lookup_tensor[t] = self.next_tensor_id
            self.next_tensor_id += 1
        return self.lookup_tensor[t]

    def get_storage_id(self, s: torch.UntypedStorage):
        if s not in self.lookup_storage:
            self.lookup_storage[s] = self.next_storage_id
            self.next_storage_id += 1
        return self.lookup_storage[s]

    def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False):
        r = MetaStorageDesc(
            id=self.get_storage_id(s),
            size=s.size(),
            # NB: We don't do the copy yet; copy happens when we start
            # creating the new storages
            data=s if self.copy_data else None,
        )
        if trace and r.id not in self.traced_storages:
            trace_structured(
                "describe_storage",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_storages.add(r.id)
        return r

    def describe_tensor(
        self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
    ):
        is_leaf = safe_is_leaf(t)
        is_view = t._is_view()
        is_sparse = t.is_sparse
        layout = t.layout
        is_nested = t.is_nested
        is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t)
        is_functorch_wrapped = is_functorch_wrapped_tensor(t)
        is_mkldnn = t.is_mkldnn
        is_batchedtensor_v = is_batchedtensor(t)
        is_legacy_batchedtensor_v = is_legacy_batchedtensor(t)
        is_gradtrackingtensor_v = is_gradtrackingtensor(t)
        is_functorch_batched_or_grad = is_batchedtensor_v or is_gradtrackingtensor_v
        is_functional = torch._is_functional_tensor(t)

        storage = None
        # NB: For compatibility, I default this to zero, as sometimes people
        # still have stuffed zero into storage offset even though the tensor
        # doesn't meaningfully have an offset
        storage_offset = 0
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
            or is_mkldnn
            # TODO: TBH, functorch wrapped tensors probably should have
            # storage associated with them
            or is_functorch_wrapped
            or is_legacy_batchedtensor_v
        ):
            # NB: We actually don't use storage to do views, but might as well
            # put it in for accuracy
            storage = self.describe_storage(t.untyped_storage(), trace=trace)
            storage_offset = t.storage_offset()  # type: ignore[assignment]

        stride = None
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
        ):
            # stride/storage_offset are called from is_functorch_wrapped,
            # view_from_base, empty_create_subclass,
            # sym_sizes_strides_storage_offset (empty_create)
            stride = t.stride()

        # NB: this technically should refer to functorch unwrapped tensor, but
        # I am (perhaps abusively) using it to store both the functorch and
        # non-functorch functional tensor
        unwrapped = None
        autograd_meta_from = None
        current_level = None
        if is_batchedtensor_v or is_gradtrackingtensor_v:
            unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace)
        # xla and lazy tensors present as functional tensors, but we want them
        # to be handled specially
        elif is_functional and t.device.type not in ("xla", "lazy"):
            if t._is_view():
                raise RuntimeError(
                    "Cannot safely fakify a view because this process drops the view information right now."
                )
            if not is_functorch_wrapped:
                torch._sync(t)
                unwrapped = self.describe_tensor(
                    torch._from_functional_tensor(t), trace=trace
                )
                autograd_meta_from = t
            else:
                reapply_views = torch._C._functionalization_reapply_views_tls()
                # NB: has side effects!
                unwrapped = self.describe_tensor(
                    _unwrap_functional_tensor(t, reapply_views), trace=trace
                )
                # TODO: It's pretty suspicious that functional tensors don't have
                # valid level and thus we just grab whatever the current level
                # is
                current_level = torch._C._functorch.current_level()

        maybe_functorch_stack = None
        if is_functorch_wrapped:
            with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack:
                pass

        attrs = None
        ctx = None
        type_v = None
        if is_traceable_wrapper_subclass_v:
            assert hasattr(t, "__tensor_flatten__")
            raw_attrs, ctx = t.__tensor_flatten__()
            attrs = {
                attr: self.describe_tensor(getattr(t, attr), trace=trace)
                for attr in raw_attrs
            }
            type_v = type(t)

        from torch.nested._internal.nested_tensor import _tensor_symint_registry

        # TODO: Is it important to enable torch.inference_mode before querying
        # these values?
        r = MetaTensorDesc(
            id=self.get_tensor_id(t),
            storage=storage,
            is_inference=t.is_inference(),
            is_leaf=is_leaf,
            requires_grad=t.requires_grad,
            # NB: ndim should be OK too but there is a disaster at
            # python test/dynamo/test_subclasses.py -k test_user_overidden_property_unsupported
            # Actually, this means that we have a little bit of a problem
            # here, which is that there is some sensitivity to how exactly an
            # access is done if you have a __torch_function__ subclass. Maybe
            # should disable torch function before doing accesses?
            ndim=t.dim(),
            dtype=t.dtype,
            is_sparse=is_sparse,
            is_mkldnn=is_mkldnn,
            is_functorch_wrapped=is_functorch_wrapped,
            is_batchedtensor=is_batchedtensor_v,
            is_legacy_batchedtensor=is_legacy_batchedtensor_v,
            is_gradtrackingtensor=is_gradtrackingtensor_v,
            is_view=is_view,
            is_conj=t.is_conj(),
            is_neg=t.is_neg(),
            is_parameter=isinstance(t, torch.nn.Parameter),
            is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v,
            is_nested=is_nested,
            nested_int=(
                _tensor_symint_registry[t].node.nested_int()
                if t in _tensor_symint_registry
                else None
            ),
            is_functional=is_functional,
            layout=layout,
            device=t.device,
            size=t.size(),
            stride=stride,
            storage_offset=storage_offset,
            dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
            sparse_dim=(
                t.sparse_dim() if t.is_sparse or is_sparse_compressed(t) else None
            ),
            dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None,
            is_coalesced=t.is_coalesced() if t.is_sparse else None,
            # TODO: I actually think recursing here is correct, but we have at
            # least an infinite cycle from base -> values -> base
            # https://github.com/pytorch/pytorch/issues/122089
            crow_indices=(
                self.describe_tensor(t.crow_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
                else None
            ),
            col_indices=(
                self.describe_tensor(t.col_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
                else None
            ),
            ccol_indices=(
                self.describe_tensor(t.ccol_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
                else None
            ),
            row_indices=(
                self.describe_tensor(t.row_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
                else None
            ),
            values=(
                self.describe_tensor(t.values(), recurse=False, trace=trace)
                if recurse and is_sparse_compressed(t)
                else None
            ),
            grad=(
                self.describe_tensor(safe_grad(t), trace=trace)
                if safe_grad(t) is not None
                else None
            ),
            creation_meta=(
                torch._C._autograd._get_creation_meta(t) if t._is_view() else None
            ),
            unwrapped=unwrapped,
            level=(
                maybe_get_level(t)
                if is_batchedtensor_v or is_gradtrackingtensor_v
                else None
            ),
            bdim=maybe_get_bdim(t) if is_batchedtensor_v else None,
            base=(
                self.describe_tensor(t._base, trace=trace)
                if recurse and t._is_view() and t._base is not None
                else None
            ),
            fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t),
            view_func=t._view_func_unsafe,
            attrs=attrs,
            ctx=ctx,
            type=type_v,
            # NB: even if functorch is enabled, don't actually save the
            # interpreter stack here unless we are actually functorch wrapped;
            # it's irrelevant for non-functorch stuff
            functorch_stack=maybe_functorch_stack,
            autograd_meta_from=autograd_meta_from,
            current_level=current_level,
            data=t if self.copy_data else None,
        )
        if trace and r.id not in self.traced_tensors:
            trace_structured(
                "describe_tensor",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_tensors.add(r.id)
        return r


@dataclass(frozen=True)
class MetaStorageDesc:
    id: MetaStorageId
    size: int
    # NB: this is only populated with copy_data True, it is not directly
    # serializable in JSON, you want to do something special here anyway
    data: Optional[torch.UntypedStorage]

    def as_json(self, describer_id):
        return {
            "id": self.id,
            "describer_id": describer_id,
            "size": self.size if isinstance(self.size, int) else repr(self.size),
        }


@dataclass(frozen=True)
class MetaTensorDesc:
    id: MetaTensorId
    ndim: int
    dtype: torch.dtype
    device: torch.device

    # NB: Sometimes, size, stride and storage_offset contain SymInt, in which
    # case this is NOT serializable. That only happens when you're
    # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we
    # can get rid of this use case entirely. Notably, even if we are
    # fakeifying a real tensor into a fake tensor with symbolic shapes, the
    # size here is NOT dynamic
    # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic
    # goes through this codepath. But it really should not LOL.
    # NB: size could potentially be None as you can override it and make it
    # throw an error, but we don't currently have any subclasses that do this
    # except C++ nested tensor but we're going to have nested int to make this
    # defined on NJT
    size: Tuple[int, ...]
    dynamo_dynamic_indices: List[int]

    layout: torch.layout = torch.strided
    is_inference: bool = False
    is_leaf: bool = False
    requires_grad: bool = False
    is_sparse: bool = False
    is_mkldnn: bool = False
    is_functorch_wrapped: bool = False
    is_batchedtensor: bool = False
    is_legacy_batchedtensor: bool = False
    is_gradtrackingtensor: bool = False
    is_view: bool = False
    is_nested: bool = False
    # We eagerly symbolicize the associated nested int for e.g. offsets / lengths
    # metadata if that offsets is already associated with a nested int.
    # See test_construct_from_jagged_with_input_offsets_mixed_case.
    nested_int: Optional[int] = None
    is_traceable_wrapper_subclass: bool = False
    is_functional: bool = False
    is_conj: bool = False
    is_neg: bool = False
    is_parameter: bool = False
    stride: Optional[Tuple[int, ...]] = None
    storage_offset: int = 0
    # NB: We have a choice whether or not to store the id or a direct pointer
    # to the data structure. For ease of use, we store the data structure,
    # but this means that when we serialize, we have to swizzle these pointers
    # back into ids (so we have accurate aliasing relationships)
    storage: Optional[MetaStorageDesc] = None
    sparse_dim: Optional[int] = None  # is_sparse, is_sparse_compressed
    dense_dim: Optional[int] = None  # is_sparse, is_sparse_compressed
    is_coalesced: Optional[bool] = None  # is_sparse
    crow_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    col_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    ccol_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    row_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    values: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    unwrapped: Optional[MetaTensorDesc] = None  # is_functorch_wrapped
    bdim: Optional[int] = None  # is_functorch_wrapped
    base: Optional[MetaTensorDesc] = None  # is_view
    attrs: Optional[Dict[str, MetaTensorDesc]] = None  # is_traceable_wrapper_subclass
    creation_meta: Optional[CreationMeta] = None
    grad: Optional[MetaTensorDesc] = None

    # Everything below is NOT serializable, need some more work

    _UNSERIALIZABLE: ClassVar[List[str]] = [
        "ctx",
        "type",
        "fake_mode",
        "view_func",
        "level",
        "current_level",
        "functorch_stack",
        "autograd_meta_from",
        "data",
        "nested_int",
    ]

    ctx: Optional[object] = None  # is_traceable_wrapper_subclass
    type: Optional[Type] = None  # is_traceable_wrapper_subclass
    fake_mode: Optional[FakeTensorMode] = None
    view_func: Optional[
        Callable[
            [
                torch.Tensor,
                Callable[[int], int],
                Callable[[torch.Tensor], torch.Tensor],
            ],
            torch.Tensor,
        ]
    ] = None
    # level looks serializable, but actually it is meaningless without
    # the functorch_stack below
    level: Optional[int] = None  # is_functorch_wrapped
    current_level: Optional[int] = None
    functorch_stack: Optional[List[CInterpreter]] = None
    autograd_meta_from: Optional[torch.Tensor] = None

    # This is only populated on copy_data, and typically is not used at all,
    # except for some of our meta-ification paths that don't properly use
    # storage (pro-tip: you should use storage)
    data: Optional[torch.Tensor] = None

    # Faithfully serializing functorch tensors will not be too difficult.
    # We only need to consider grad/vmap interpreters, and their internal
    # state is only bools (mostly what the grad enabled/disabled state
    # should be in the lower layer). Beyond that, tensors just need to
    # precisely indicate which particular interpreter they correspond
    # to (we then replace level with a pointer to the interpreter stack.)
    # However, this use of functorch is very "non-lexical" so it's not
    # entirely clear how to make it all lexical again, so we haven't done
    # it for now.
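
    # A rough summary of the serialization scheme implemented by as_json below
    # (descriptive note; the exact payload is best-effort and not a stable
    # format): fields that are identical to their dataclass defaults are
    # skipped (as is an empty dynamo_dynamic_indices), nested
    # MetaTensorDesc/MetaStorageDesc references collapse to their integer ids,
    # dtypes/devices/layouts/SymInts are repr()'d, entries in _UNSERIALIZABLE
    # are repr()'d purely for debugging, and "data"/"autograd_meta_from" are
    # mapped to None. For a plain 2x3 CPU float tensor the result looks
    # roughly like:
    #
    #     {"id": 0, "ndim": 2, "dtype": "torch.float32", "size": [2, 3],
    #      "stride": [3, 1], "storage": 0, ..., "describer_id": 0}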

    # NB: This will reference numeric IDs, and it is assumed that you've
    # already serialized everything this recursively references
    def as_json(self, describer_id):
        def json(k, v):
            # Some best-effort debugging serialization for unserializable
            # fields (feel free to add other special cases as appropriate)
            if k in ["data", "autograd_meta_from"]:
                return None  # never repr these
            if k in set(MetaTensorDesc._UNSERIALIZABLE):
                return repr(v)
            if isinstance(v, (torch.device, torch.dtype, torch.layout)):
                return repr(v)
            if isinstance(v, torch.SymInt):
                return repr(v)
            if isinstance(v, (tuple, list)):
                return [json(k, v1) for v1 in v]
            if isinstance(v, (MetaStorageDesc, MetaTensorDesc)):
                return v.id
            if isinstance(v, CreationMeta):
                return str(v)
            if k == "attrs" and isinstance(v, dict):
                return {k1: v1.id for k1, v1 in v.items()}
            return v

        r = {
            field.name: json(field.name, getattr(self, field.name))
            for field in dataclasses.fields(self)
            if not (
                getattr(self, field.name) is field.default
                or (
                    field.name == "dynamo_dynamic_indices"
                    and not getattr(self, field.name)
                )
            )
        }
        r.update({"describer_id": describer_id})
        return r

    @property
    def shape(self):
        return self.size


# A more faithful reproduction would do a copy on the entire
# storage, but this needs to be done carefully because the
# underlying storage could have larger extent than is implied
# by size/stride. The real fix is to properly call
# meta_storage recursively here.
#
# These "safe" functions are intended to be used under no_dispatch() mode.
# The no_dispatch() here is intended to prevent ambient fake tensor mode from
# fakeifying the operation. But if we are given an honest to goodness
# FakeTensor as src, we MUST NOT run the copy/clone operation. A better way
# to do this would be to not use no_dispatch and instead just disable fake
# tensor mode only (allowing for subclass dispatch to occur)
def _safe_copy(dst, src):
    if type(src) is not torch.Tensor:
        return
    dst.copy_(src)


def _safe_clone(src):
    if type(src) is not torch.Tensor:
        return None
    return src.clone()


# This is a class for converting multiple tensors into meta tensors which
# share the same view/storage structure. The operation model is you allocate
# one of these, and then call it repeatedly on all the tensors you want to
# convert. It's important to use the same object for tensors you want to
# share storage because this is how we correlate shared storages to the same
# meta storages. This class will hold weak references to cached tensors
# and tensor storages.
class MetaConverter:
    def __init__(self, *, copy_data: bool = False):
        # Maps MetaStorageId to UntypedStorage
        self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
        # Maps MetaTensorId to torch.Tensor (typically a meta tensor or
        # FakeTensor)
        self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
        self.hit = 0
        self.miss = 0
        self.del_hook = None
        self.arg_cnt = 0
        # Ensures real_storage/real_tensor are populated on the resulting
        # metaified storage/tensor. The naming of this attribute is load
        # bearing: FakeTensor relies on real tensor being set to exactly this
        # value
        self.copy_data = copy_data
        self.describer = MetaTensorDescriber(copy_data=copy_data)

    def successful(self):
        return self.hit > 0 and self.miss == 0

    def get_tensor_memo(self, t: MetaTensorDesc):
        return self.tensor_memo.get(t.id, None)

    def set_tensor_memo(self, t: MetaTensorDesc, v):
        self.tensor_memo[t.id] = v

    def get_storage_memo(self, s: MetaStorageDesc):
        return self.storage_memo.get(s.id, None)

    def set_storage_memo(self, s: MetaStorageDesc, v):
        self.storage_memo[s.id] = v

    def meta_storage(self, s: MetaStorageDesc, callback):
        # If we are fakeifying a tensor that has a secretly-zero-sized storage,
        # we need to make sure to resize the meta storage too.
        if self.get_storage_memo(s) is None:
            r_s = callback(
                lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
            ).untyped_storage()
            if self.copy_data:
                # NB: no_dispatch is needed because internally storage copy is
                # implemented as Tensor operations
                with torch.no_grad(), no_dispatch():
                    assert s.data is not None
                    r_s.real_storage = s.data.clone()
            self.set_storage_memo(s, r_s)
            return r_s
        else:
            return self.get_storage_memo(s)

    # This function assumes that it's possible to do the conversion
    # NB: name here is used in a conventional way by Dynamo; it corresponds
    # precisely to the Source.name() of the tensor we're fakeifying and
    # corresponds to a valid Python expression. When we construct sub-names
    # as part of this process, we will maintain this invariant! (Even though
    # other users of this may not need this property to be upheld.)
    def meta_tensor(
        self,
        t: MetaTensorDesc,
        shape_env: Optional[ShapeEnv] = None,
        callback=lambda t: t(),
        source: Optional[Source] = None,
        symbolic_context: Optional[SymbolicContext] = None,
    ):
        if source is None:
            from torch._dynamo.source import ConstantSource

            # TODO: make a dedicated UnknownSource for this?
            source = ConstantSource(
                f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
            )

        # This indicates you set no_dispatch() before calling into this
        # function. This is an error: we may be creating fake tensors and
        # will perform operations on them which need fake tensor mode to
        # be active. You will segfault if you are in a no_dispatch() block.
        assert not torch._C._dispatch_tls_local_exclude_set().has(
            torch._C.DispatchKey.Python
        )
        arg_cnt = self.arg_cnt
        self.arg_cnt += 1

        # When we make as_strided calls, we end up generating a guard
        # that the new as_strided tensor is in bounds for the old storage
        # for the base (since as_strided calls can "bust" out of their
        # bounding box.) This guard is unnecessary: if a user is able
        # to provide us a tensor with the view base setup this way, we
        # don't need to produce a guard, because the fact that they
        # were able to produce the view base means it's in bounds.
        #
        # Now, ordinarily, this guard would be harmless. However, the
        # generated guard refers to variables bound on the base variable.
        # At the moment, Dynamo doesn't actually guard on x._base, because
        # according to Voz this results in a lot of spurious invalidations,
        # and also if the user doesn't directly make use of _base, it's
        # pointless anyway (because programs should be parametric over
        # whether or not the input tensor is a view--unless you're
        # mutating the input, but that's a whole 'nother ballgame). So
        # for expediency, we suppress these guards so we don't have to
        # deal with this (yet, anyway.)
        #
        # NB: An old version of this code suppressed guards for ALL operations
        # happening during meta conversion, not just as_strided calls.
        # This is too aggressive: we do duck sizing and 0/1 simplification
        # as we allocate variables, and we do need to register guards for
        # these cases.
        maybe_suppress: Callable[[], Any] = contextlib.nullcontext
        if shape_env is not None:
            maybe_suppress = shape_env.suppress_guards

        def sym_sizes_strides_storage_offset(
            t: MetaTensorDesc, src, symbolic_context=symbolic_context
        ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
            assert t.stride is not None
            if shape_env is not None:
                fake_mode = t.fake_mode
                if fake_mode is not None and fake_mode.shape_env is shape_env:
                    # Don't reallocate the sizes; the shape envs are the same,
                    # so reuse the old sizes/strides/etc
                    return (t.size, t.stride, t.storage_offset)
                else:
                    # TODO: deduplicate this
                    t_size = tuple(
                        shape_env._maybe_specialize_sym_int_with_hint(sz)
                        for sz in t.size
                    )
                    t_stride = tuple(
                        shape_env._maybe_specialize_sym_int_with_hint(sd)
                        for sd in t.stride
                    )
                    t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint(
                        t.storage_offset
                    )
                    return shape_env._create_symbolic_sizes_strides_storage_offset(
                        t_size,
                        t_stride,
                        t_storage_offset,
                        [d in t.dynamo_dynamic_indices for d in range(t.ndim)],
                        src,
                        symbolic_context=symbolic_context,
                    )
            else:
                return (t.size, t.stride, t.storage_offset)

        def empty_create(
            inner_t: MetaTensorDesc, inner_src, symbolic_context=symbolic_context
        ):
            (
                inner_sizes,
                inner_strides,
                inner_storage_offset,
            ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
            return torch.empty_strided(
                inner_sizes,
                inner_strides,
                dtype=inner_t.dtype,
                device="meta",
            )

        # Creates a subclass instance with empty inner tensors according to the specified
        # symbolic context.
        def empty_create_subclass(
            t: MetaTensorDesc,
            outer_size,
            outer_stride,
            symbolic_context=symbolic_context,
            callback=callback,
            source=source,
        ):
            from torch._dynamo.source import AttrSource
            from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext

            assert t.attrs is not None
            assert t.type is not None
            # NB: t.ctx could be None if the subclass in question has no
            # meaningful context

            # Note: transform_subclass will use __tensor_unflatten__ to generate
            # a fresh subclass wrapper with outer sizes / strides according to the
            # outer symbolic context (passed in to this function). Inner size / stride
            # / storage offset symbols are allocated according to the appropriate inner
            # symbolic contexts, after which the checks in transform_subclass() will
            # relate them to the outer metadata as possible.
            #
            # Morally, the code here is the same as transform_subclass, but we've
            # written it from scratch to read EmptyCreateSubclass
            outer_size = outer_size if outer_size is not None else t.size
            outer_stride = outer_stride if outer_stride is not None else t.stride

            assert symbolic_context is None or isinstance(
                symbolic_context, SubclassSymbolicContext
            )

            def _empty_create_subclass(
                t, outer_size, outer_stride, symbolic_context, callback, source
            ):
                # We are hitting plain meta_desc tensor so actually
                # create a tensor here.
                if t.attrs is None:
                    return self.meta_tensor(
                        t,
                        shape_env=shape_env,
                        callback=callback,
                        source=source,
                        symbolic_context=symbolic_context,
                    )

                inner_tensors = {}
                for attr, meta_tensor_desc in t.attrs.items():
                    current_context = None
                    if symbolic_context is not None:
                        current_context = symbolic_context.inner_contexts[attr]

                    current_source = AttrSource(source, attr)
                    new_empty_tensor = _empty_create_subclass(
                        meta_tensor_desc,
                        meta_tensor_desc.size,
                        meta_tensor_desc.stride,
                        current_context,
                        callback,
                        current_source,
                    )
                    inner_tensors[attr] = new_empty_tensor

                return t.type.__tensor_unflatten__(
                    inner_tensors, t.ctx, outer_size, outer_stride
                )

            sub = _empty_create_subclass(
                t, outer_size, outer_stride, symbolic_context, callback, source
            )

            # NB: Purposefully guard here to simplify the inner / outer symbols.
            # Using sym_eq() for symbolic comparison can result in an expression that's too
            # difficult to guard on, so we use == here.
            assert sub.shape == outer_size, (
                f"Expected return value from {t.type}__tensor_unflatten__() to have "
                f"shape equal to {outer_size}, but got: {sub.shape}"
            )
            assert sub.stride() == outer_stride, (
                f"Expected return value from {t.type}__tensor_unflatten__() to have "
                f"stride equal to {outer_stride}, but got: {sub.stride()}"
            )

            return sub

        # Returns an all-dynamic symbolic context used for metafying the given tensor with
        # fully dynamic dims. This is useful when fake-ifying intermediate tensors in
        # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
        # don't want to over-specialize during view replay.
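        # Illustrative sketch of what this helper builds (spelled out from the
        # code below as a descriptive note, not a separate API): for a plain
        # 2-D dense tensor it returns roughly
        #
        #     StatelessSymbolicContext(
        #         dynamic_sizes=[DimDynamic.DYNAMIC, DimDynamic.DYNAMIC],
        #         constraint_sizes=[None, None],
        #         view_base_context=None,
        #     )
        #
        # while views recurse into their base (via an AttrSource on "_base")
        # and wrapper subclasses recurse into their inner tensors using a
        # SubclassSymbolicContext.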
        def all_dynamic_symbolic_context(
            t: MetaTensorDesc, source, shape_env, callback
        ):
            from torch._dynamo.source import AttrSource
            from torch.fx.experimental.symbolic_shapes import (
                DimDynamic,
                StatelessSymbolicContext,
                SubclassSymbolicContext,
            )

            view_base_context: Optional[SymbolicContext] = None
            if t.is_view:
                assert t.base is not None
                view_base_context = all_dynamic_symbolic_context(
                    t.base, AttrSource(source, "_base"), shape_env, callback
                )

            t_symbolic_context: SymbolicContext
            t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
            if t.is_traceable_wrapper_subclass:
                assert t.attrs is not None
                inner_contexts: Dict[str, SymbolicContext] = {}
                for attr, inner in t.attrs.items():
                    assert isinstance(attr, str)
                    inner_contexts[attr] = all_dynamic_symbolic_context(
                        inner, AttrSource(source, attr), shape_env, callback
                    )
                t_symbolic_context = SubclassSymbolicContext(
                    dynamic_sizes=t_dynamic_sizes,
                    constraint_sizes=[None] * t.ndim,
                    inner_contexts=inner_contexts,  # type: ignore[arg-type]
                    tensor_source=source,
                    view_base_context=view_base_context,
                )
            else:
                t_symbolic_context = StatelessSymbolicContext(
                    dynamic_sizes=t_dynamic_sizes,
                    constraint_sizes=[None] * t.ndim,
                    view_base_context=view_base_context,
                )

            return t_symbolic_context

        # Returns a fake-ified version of an input view tensor t, given an already fake-ified
        # base. At a high level, we want two things:
        #   1. fake_t should have the same view relationship to the given fake base as the
        #      input t has to its _base.
        #   2. fake_t should have symbolic sizes / strides / storage offset according to the
        #      appropriate symbolic context (i.e. from the automatic dynamic algorithm).
        #
        # We currently take different strategies across view types:
        #   * For dense -> dense views, accomplish both (1) and (2) simultaneously via an
        #     as_strided() call on the fake-ified base, passing symbolic metadata.
        #   * For views involving subclasses, perform view replay using view funcs to
        #     achieve (1). It's necessary for (2) to swap out any closed-over state in
        #     the view funcs with symbolicized SymInts and fake-ified tensors. Doing this
        #     avoids specialization (and thus over-eager simplification of symbols) that
        #     could occur during view replay on the fake-ified base.
        #
        # Examples:
        #   * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled
        #     with an as_strided() call on the fake base passing symbolic metadata.
        #   * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg
        #     is made symbolic to avoid invalid specialization and view replay is then
        #     done to reconstruct the view.
        #   * _nested_from_jagged(values, offsets) is a dense -> subclass view
        #     that returns a subclass instance from a dense values tensor. The offsets
        #     tensor is closed over in the view func, as it can be considered view metadata.
        #     First, the offsets tensor is fake-ified according to the inner symbolic
        #     context and with the correct relationship to the outer size / stride metadata.
        #     Then view replay is done, swapping in the fake offsets so the view replay output
        #     is fully fake with no invalid specialization.
        def view_from_base(
            base: torch.Tensor, t: MetaTensorDesc, source=source, shape_env=shape_env
        ):
            # fake-ify t's metadata according to the outer symbolic context
            (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
                t, source
            )
            if (
                not t.is_traceable_wrapper_subclass
                and not is_traceable_wrapper_subclass(base)
            ):
                # Dense -> Dense view case uses as_strided() to construct view relationship.
                # TODO: Change this logic to use view replay for consistency?
                # It's likely there is no view func available.
                with maybe_suppress():
                    return base.as_strided(sizes, strides, storage_offset)

            from torch._dynamo.source import EphemeralSource
            from torch.fx.experimental.symbolic_shapes import (
                StatelessSymbolicContext,
                sym_eq,
            )

            def symint_visitor_fn(s):
                nonlocal symbolic_context
                from torch.fx.experimental.symbolic_shapes import DimDynamic

                all_static_sizes = (
                    symbolic_context is not None
                    and isinstance(symbolic_context, StatelessSymbolicContext)
                    and all(
                        x is DimDynamic.STATIC for x in symbolic_context.dynamic_sizes
                    )
                )
                # Can't just rely on shape env being None - dynamo always initializes it
                if all_static_sizes or shape_env is None:
                    return s

                # NB: The symbol here is expected to be simplified out because we a priori
                # allocate inner and outer symbols according to the appropriate symbolic
                # contexts and prefer those over this symbol during symbol simplification
                # (via usage of EphemeralSource below). This -shouldn't- happen, but if
                # this symbol somehow leaks out beyond the view tensor's shape metadata, our
                # assumption of it being simplified out will fail and it may be guarded on,
                # which will hard error.
                sym_source = EphemeralSource("symint_visitor_fn")

                symbol = shape_env.create_symbol(s, sym_source, positive=None)
                return shape_env.create_symintnode(symbol, hint=s, source=sym_source)

            real_to_fake_mapping = {}
            if t.is_traceable_wrapper_subclass:
                assert t.attrs is not None
                # NB: t.ctx could be None if the subclass in question has no
                # meaningful context
                assert t.type is not None

                # Fake-ify t naively here; this is only done so we can get fake-ified inner
                # tensors with the correct relationships to the outer sizes / strides for use
                # in view replay. It's done beforehand here because it's not easy to do when
                # visiting tensors one-by-one during view replay.
                #
                # Example:
                #   Consider a Dense -> NJT view. NJT has (values, offsets) components and we
                #   want a view of values with the offsets closed over. As the offsets component
                #   is needed to describe the output view, it's important that it's fakeified
                #   correctly.
                fake_t = empty_create_subclass(
                    t, outer_size=sizes, outer_stride=strides
                )
                attrs, _ = fake_t.__tensor_flatten__()
                for attr in attrs:
                    real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr)

            def tensor_visitor_fn(
                visited_t: torch.Tensor,
                # These arguments are never passed, we just use them to close
                # over these relevant values
                shape_env=shape_env,
                callback=callback,
            ):
                # It's possible to close over an undefined tensor (e.g. NJT's lengths).
                if visited_t is None:
                    return None

                # NB: visited_t being a Tensor here is very naughty! Should
                # have already been described

                # Fake inner tensors of view subclasses will come from the mapping built above.
                visited_id = self.describer.get_tensor_id(visited_t)
                fake_visited_t = real_to_fake_mapping.get(visited_id, None)
                if fake_visited_t is not None:
                    return fake_visited_t

                visited_desc = self.describer.describe_tensor(visited_t)

                # For other closed-over tensor state, fake-ify it as all dynamic with an
                # ephemeral source. This avoids invalid specialization during view replay.
                # If we find that in practice the usage of ephemeral sources isn't enough
                # to guarantee that we don't have guards on these symbols, we may need to
                # explicitly suppress guards (as is done for _base in the dense -> dense
                # view case).
                temp_source = EphemeralSource("tensor_visitor_fn")
                return self.meta_tensor(
                    visited_desc,
                    shape_env,
                    callback,
                    source=temp_source,
                    symbolic_context=all_dynamic_symbolic_context(
                        visited_desc, temp_source, shape_env, callback
                    ),
                )

            # Replay the view, swapping out any non-symbolic SymInts or real tensors
            # for symbolic SymInts or fake tensors.
            assert t.view_func is not None
            # NB: we do NOT suppress guards here, we need to remove ephemeral
            # sources
            fake_t = t.view_func(base, symint_visitor_fn, tensor_visitor_fn)

            # Ensure the output has symbolic shapes according to the outer symbolic context.
            # These checks should simplify out any symbols created for closed-over view func
            # SymInts.
            torch._check(sym_eq(fake_t.size(), sizes))
            torch._check(sym_eq(fake_t.stride(), strides))
            torch._check(sym_eq(fake_t.storage_offset(), storage_offset))
            return fake_t

        if self.get_tensor_memo(t) is None:
            GRAD_TENSOR_SENTINEL_VALUE = -2

            with torch.inference_mode(t.is_inference):
                if t.is_sparse:
                    is_leaf = t.is_leaf

                    # The lambda function below is similar to
                    # `t.to(device='meta')` except the latter
                    # preserves nnz value
                    r = callback(
                        lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
                            t.sparse_dim,
                            t.dense_dim,
                            t.size,
                            dtype=t.dtype,
                            layout=torch.sparse_coo,
                            device="meta",
                        )
                    )
                    if self.copy_data:
                        # Pray that sparse clone doesn't lose information
                        assert t.data is not None
                        with torch.no_grad(), no_dispatch():
                            r.real_tensor = _safe_clone(t.data)
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    # Note [is_coalesced is dispatched]
                    # Strangely enough, is_coalesced() is a dispatched operator,
                    # which means that it will get caught by fake tensor mode.
                    # Ordinarily this would error, but there's some logic in
                    # fake tensor to ensure this doesn't happen.
                    r._coalesced_(t.is_coalesced)
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        # This should probably use DelayedError,
                        # but clone is fine for now for sparse tensors.
                        # (DelayedError does not work for sparse because it causes
                        # the Fake sparse tensor to "lose" its fakeness)
                        r = r.clone()
                        with torch.enable_grad():
                            r._coalesced_(t.is_coalesced)
                elif is_sparse_compressed_layout(t.layout):
                    is_leaf = t.is_leaf

                    if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
                        assert t.sparse_dim is not None
                        assert t.dense_dim is not None
                        assert t.values is not None
                        batch_dim = t.ndim - t.sparse_dim - t.dense_dim
                        blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
                    else:
                        blocksize = ()
                    if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                        assert t.crow_indices is not None
                        index_dtype = t.crow_indices.dtype
                    else:
                        assert t.ccol_indices is not None
                        index_dtype = t.ccol_indices.dtype

                    r = callback(
                        lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
                            0,
                            t.dense_dim,
                            t.shape,
                            blocksize,
                            index_dtype,
                            layout=t.layout,
                            dtype=t.dtype,
                            device="meta",
                        )
                    )
                    if self.copy_data:
                        # Pray sparse clone doesn't lose information
                        assert t.data is not None
                        with torch.no_grad(), no_dispatch():
                            r.real_tensor = _safe_clone(t.data)
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        r = torch._C._functions.DelayedError(
                            "Internal error: Tried to backward() through example input",
                            1,
                        )(r)
                elif t.is_nested and not t.is_traceable_wrapper_subclass:
                    # TODO: Handle this better in Dynamo?
                    # There are checks there now, but this can still be triggered by a dense
                    # tensor graph input that is a view of a strided NT.
                    from torch._dynamo.exc import unimplemented

                    unimplemented(
                        "strided nested tensors are not supported by meta conversion"
                    )
                elif t.is_mkldnn:
                    is_leaf = t.is_leaf
                    sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
                        t, source
                    )
                    # TODO: This doesn't seem right, where's the MKLDNN'ness
                    # lol
                    r = callback(
                        lambda: torch.empty_strided(
                            sizes, strides, dtype=t.dtype, device="meta"
                        )
                    )
                    if self.copy_data:
                        with torch.no_grad(), no_dispatch():
                            assert t.size is not None
                            assert t.stride is not None
                            r.real_tensor = torch.empty_strided(
                                t.size, t.stride, dtype=t.dtype, device=t.device
                            )
                            assert t.data is not None
                            _safe_copy(r.real_tensor, t.data)
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        r = torch._C._functions.DelayedError(
                            "Internal error: Tried to backward() through example input",
                            1,
                        )(r)
                elif t.is_functorch_wrapped:
                    if t.is_view:
                        from torch._dynamo.exc import unimplemented

                        unimplemented(
                            "view functorch tensors are not supported by meta conversion"
                        )

                    # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
                    # in a FakeTensor
                    def _to_fake_tensor(t: MetaTensorDesc):
                        # TODO: why aren't the recursive calls going to
                        # meta_tensor
                        if t.is_batchedtensor:
                            assert t.unwrapped is not None
                            assert t.level is not None
                            assert t.bdim is not None
                            ft = _to_fake_tensor(t.unwrapped)
                            lvl = t.level
                            bdim = t.bdim
                            # You cannot create functorch tensors without
                            # having the ambient functorch interpreter stack
                            # available, as the level refers to things in the
                            # stack
                            with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
                                t.functorch_stack
                            ):
                                r = _add_batch_dim(ft, bdim, lvl)
                        elif t.is_gradtrackingtensor:
                            assert t.unwrapped is not None
                            assert t.level is not None
                            disable_functorch = torch._C._DisableFuncTorch
                            with disable_functorch():
                                ft = _to_fake_tensor(t.unwrapped)
                            lvl = t.level
                            if lvl == GRAD_TENSOR_SENTINEL_VALUE:
                                r = ft
                            else:
                                with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
                                    t.functorch_stack
                                ):
                                    r = torch._C._functorch._wrap_for_grad(ft, lvl)

                            is_leaf = t.is_leaf
                            if t.requires_grad and safe_is_leaf(r):
                                r.requires_grad = True
                            elif t.requires_grad and not is_leaf:
                                r = torch._C._functions.DelayedError(  # type: ignore[assignment]
                                    "Internal error: Tried to backward() through example input",
                                    1,
                                )(
                                    r  # type: ignore[arg-type]
                                )
                        elif t.is_functional:
                            assert t.unwrapped is not None
                            assert t.current_level is not None
                            ft = self.meta_tensor(
                                t.unwrapped,
                                shape_env=shape_env,
                                callback=callback,
                                # NB: reuse these exactly, we treat the
                                # functional tensor as "invisible".
                                # TODO: Actually this all probably doesn't
                                # work, take a closer look.
                                source=source,
                                symbolic_context=symbolic_context,
                            )
                            r = _wrap_functional_tensor(ft, t.current_level)
                            # TODO: is_leaf/requires_grad?
                        else:
                            assert t.stride is not None

                            sizes = t.size
                            strides = t.stride
                            r = callback(
                                lambda: torch.empty_strided(
                                    sizes,
                                    strides,
                                    dtype=t.dtype,
                                    device="meta",
                                )
                            )
                            if self.copy_data:
                                with torch.no_grad(), no_dispatch():
                                    r.real_tensor = torch.empty_strided(  # type: ignore[attr-defined]
                                        t.size,
                                        t.stride,
                                        dtype=t.dtype,
                                        device=t.device,
                                    )
                                    assert t.data is not None
                                    _safe_copy(r.real_tensor, t.data)  # type: ignore[attr-defined]
                        return r

                    r = _to_fake_tensor(t)

                elif t.is_functional and t.device.type not in ["xla", "lazy"]:
                    assert t.unwrapped is not None
                    assert not t.is_functorch_wrapped  # handled above
                    unwrapped = self.meta_tensor(
                        t.unwrapped,
                        shape_env=shape_env,
                        callback=callback,
                        source=source,
                        symbolic_context=symbolic_context,
                    )
                    r = torch._to_functional_tensor(unwrapped)
                    torch._mirror_autograd_meta_to(t.autograd_meta_from, r)  # type: ignore[attr-defined]

                elif t.is_view:
                    # Construct views in two steps: recursively meta-fy their
                    # base, and then create view(s) off that. NB: doing it
                    # directly from storage is WRONG because this won't cause
                    # version counters to get shared.

                    assert t.base is not None

                    base_symbolic_context = None
                    if shape_env and symbolic_context is not None:
                        from torch.fx.experimental.symbolic_shapes import (
                            StatelessSymbolicContext,
                        )

                        assert isinstance(symbolic_context, StatelessSymbolicContext)
                        # NB: This should generally be set when the input is a view,
                        # but the exception right now is for fake-ifying grads, which is
                        # a work in progress.
                        if symbolic_context.view_base_context is not None:
                            base_symbolic_context = symbolic_context.view_base_context

                    base = self.meta_tensor(
                        t.base,
                        shape_env,
                        callback,
                        source=torch._dynamo.source.AttrSource(source, "_base"),
                        symbolic_context=base_symbolic_context,
                    )

                    def is_c_of_r(complex_dtype, real_dtype):
                        return (
                            utils.is_complex_dtype(complex_dtype)
                            and utils.corresponding_real_dtype(complex_dtype)
                            == real_dtype
                        )

                    # In some situations, MetaConverter may be called in a
                    # context where autograd is disabled. For the _is_view
                    # assert to pass, we have to setup the autograd view
                    # metadata anyway. Do this by reenabling the
                    # ADInplaceOrView key. This is kind of a hack.
                    old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
                        torch._C.DispatchKey.ADInplaceOrView
                    )
                    torch._C._dispatch_tls_set_dispatch_key_excluded(
                        torch._C.DispatchKey.ADInplaceOrView, False
                    )
                    try:
                        if base.dtype == t.dtype:
                            pass
                        elif is_c_of_r(base.dtype, t.dtype):
                            base = torch.view_as_real(base)
                        elif is_c_of_r(t.dtype, base.dtype):
                            base = torch.view_as_complex(base)
                        else:
                            # This is not guaranteed to succeed. If it fails, it
                            # means there is another dtype-converting view function
                            # that hasn't been handled here
                            base = base.view(t.dtype)

                        # This is very tricky. Naively, you might expect this
                        # to hold:
                        #
                        #   if t.requires_grad and not safe_is_leaf(t)
                        #       assert t._base.requires_grad
                        #
                        # But it's not true! As you can see in the following
                        # program:
                        #
                        #   x = torch.zeros(4)
                        #   y = x.view(1, 4)
                        #   y.requires_grad = True
                        #   z = y.view(1, 1, 4)
                        #   assert z._base is x
                        #
                        # So we may have to do *two* views out of the base to
                        # recreate this situation.
                        if t.is_leaf:
                            # Leaf views that track view metadata are created by
                            # creating a view inside a no_grad block
                            with torch.no_grad():
                                r = view_from_base(base, t)
                            # As it's a leaf, we can directly assign requires_grad
                            r.requires_grad = t.requires_grad
                        else:
                            if t.base.requires_grad == t.requires_grad:
                                # Easy case, just run the view op
                                with torch.enable_grad():
                                    r = view_from_base(base, t)

                                # NB: We don't actually faithfully replicate
                                # autograd connectivity, but that doesn't matter
                                # today. See following for more info:
                                # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
                            else:
                                # Obscure case. Create a leaf view and give it the
                                # correct requires_grad, then do the final view.
                                # NB: Can't have a non-leaf without requiring grad!
                                assert t.requires_grad
                                with torch.no_grad():
                                    mid = base.view(base.shape)
                                mid.requires_grad = t.requires_grad
                                with torch.enable_grad():
                                    r = view_from_base(mid, t)
                            # The CreationMeta influences whether or not inplace
                            # mutation is an error or not. So we need to make
                            # sure we properly propagate this as well.
                            assert t.creation_meta is not None
                            torch._C._autograd._set_creation_meta(r, t.creation_meta)
                    finally:
                        torch._C._dispatch_tls_set_dispatch_key_excluded(
                            torch._C.DispatchKey.ADInplaceOrView, old_exclude
                        )

                else:
                    is_leaf = t.is_leaf

                    # Graph-Break for wrapped tensors
                    if (
                        not (t.is_batchedtensor or t.is_gradtrackingtensor)
                        and t.is_functorch_wrapped
                    ) or t.is_legacy_batchedtensor:
                        return NotImplemented

                    (
                        sizes,
                        strides,
                        storage_offset,
                    ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)

                    # If we have a subclass that desugars into dense tensors,
                    # perform our callback on each inner tensor.
                    if t.is_traceable_wrapper_subclass:
                        r = empty_create_subclass(
                            t, outer_size=sizes, outer_stride=strides
                        )
                    else:
                        r = callback(
                            lambda: torch.empty_strided(
                                sizes,
                                strides,
                                dtype=t.dtype,
                                device="meta",
                            )
                        )
                        if self.copy_data:
                            with torch.no_grad(), no_dispatch():
                                assert t.size is not None
                                assert t.stride is not None
                                r.real_tensor = torch.empty_strided(
                                    t.size, t.stride, dtype=t.dtype, device=t.device
                                )
                                _safe_copy(r.real_tensor, t.data)

                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = t.requires_grad
                        if not is_leaf:
                            # Fake up some autograd history.
                            # Note: we *used* to call .clone() here to mock up some autograd history.
                            # This is bad for subclasses.
                            # Consider the case where you have a wrapper subclass that is contiguous,
                            # but its inner tensor is noncontiguous().
                            # .clone() (or other ops) will have the side effect of changing
                            # the metadata of the inner tensor.
                            # So instead, we now have a dedicated fn to set autograd history,
                            # without inadvertently changing other metadata.
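                            # (Descriptive note, assumed behavior of DelayedError: it attaches
                            # an autograd node that raises the message below if backward()
                            # ever reaches it, so r becomes a non-leaf with fake autograd
                            # history while its sizes/strides/storage stay untouched.)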
                            r = torch._C._functions.DelayedError(
                                "Internal error: Tried to backward() through example input",
                                1,
                            )(r)

                    s = t.storage
                    assert s is not None
                    if s.id not in self.storage_memo and (
                        r.is_nested
                        or (
                            r.stride() == strides
                            and r.storage_offset() == storage_offset
                        )
                    ):
                        # You're normal and happy, install the fresh storage into the memo
                        self.set_storage_memo(s, r.untyped_storage())
                        if self.copy_data:
                            r.untyped_storage().real_storage = (
                                r.real_tensor.untyped_storage()
                            )
                    else:
                        # You're in crazy town; somehow you gave us a tensor
                        # that wasn't a view, but had nonzero storage offset,
                        # nontrivial strides (such that clone() couldn't
                        # preserve them), or already aliases with another
                        # tensor's storage. The most typical way to end
                        # up here is with set_. So use set_ to bludgeon this
                        # in.
                        r_s = self.meta_storage(s, callback=callback)
                        # NB: In principle, this should always work, but there
                        # is some subtle difference in the autograd metadata
                        # that means we will backprop the set_ call, even if
                        # r is declared as an input to grad.
                        # See https://github.com/pytorch/pytorch/issues/87956
                        # for the reproducer.
                        # NB: The in_kernel_invocation_manager here is necessary
                        # for fake tensor. If we run the set_ call with fake
                        # tensor on, r will improperly report that it is NOT a
                        # meta tensor but a cpu tensor, and then the set_ call
                        # will fail due to device mismatch. no_dispatch() is
                        # not enough, because the fake tensor will still claim
                        # to be a CPU tensor and you'll end up in the CPU
                        # kernel. Arguably this is a hack; a cleaner way to
                        # solve this is to have a FakeStorage concept which
                        # would report its CPU device--no problem now! But
                        # this is difficult to do because we don't have storage
                        # subclasses. Relevant test is
                        # DynamicShapesFunctionTests::test_add_dynamic_shapes in
                        # test/dynamo/test_dynamic_shapes.py
                        maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
                        from torch._subclasses.fake_tensor import (
                            in_kernel_invocation_manager,
                            maybe_get_fake_mode,
                        )

                        mb_fake_mode = maybe_get_fake_mode(r)
                        if mb_fake_mode is not None:
                            maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
                        with torch.no_grad(), maybe_suppress():
                            with maybe_fake_mgr:
                                r.set_(r_s, storage_offset, sizes, strides)
                        if self.copy_data:
                            with torch.no_grad(), no_dispatch():
                                r.real_tensor.set_(
                                    r_s.real_storage,
                                    t.storage_offset,
                                    t.size,
                                    t.stride,
                                )

                if t.grad is not None:
                    from torch._dynamo.source import AttrSource

                    # TODO: Use a valid grad-specific symbolic context instead of recycling
                    # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
                    r.grad = self.meta_tensor(
                        t.grad,
                        shape_env,
                        callback,
                        source=AttrSource(source, "grad"),
                        symbolic_context=symbolic_context,
                    )
                torch._C._set_conj(r, t.is_conj)
                torch._C._set_neg(r, t.is_neg)
                # This can be skipped if necessary for performance reasons
                skip_leaf = (
                    t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
                )
                assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
                # Thanks to storage resizing, it's possible to end up with a tensor
                # that advertises a real size, but has a storage that actually has zero bytes.
                # Need to reflect this in the generated FakeTensor.
                if t.storage is not None and t.storage.size == 0:
                    r.untyped_storage().resize_(0)

                if t.is_parameter:
                    r._is_param = True

                # See Note: [Creating symbolic nested int]
                if t.nested_int is not None:
                    r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
                        nt_tensor_id=t.nested_int
                    )

                self.set_tensor_memo(t, r)

        return self.get_tensor_memo(t)

    def __call__(
        self,
        t,
        shape_env=None,
        *,
        callback=lambda t: t(),
        source=None,
        symbolic_context=None,
        # Controls whether or not we should dump the tensor metadata to structured logs
        # when source is not None. Because we refakify after Dynamo is done,
        # we don't want to dump info again from AOTAutograd, it is redundant.
        trace=True,
    ):
        # TODO: zero tensors? We appear to have eliminated them by
        # excluding complex for now

        # Filter out cases we don't support
        # TODO: This can probably be simplified quite a bit
        if isinstance(t, torch.Tensor):
            if (
                # Lazy tensors are not supported. Note that XLA is
                # implemented on top of lazy tensor, not excluded here; we
                # have some special handling for it; this is for XLA Dynamo
                # integration
                t.device.type == "lazy"
                or
                # Quantization is not supported
                t.is_quantized
                or
                # Views out of sparse tensors not currently supported (plain
                # sparse is supported though)
                (t._is_view() and t._base is not None and t._base.is_sparse)
            ):
                self.miss += 1
                return NotImplemented
            else:
                self.hit += 1
        elif torch.overrides.is_tensor_like(t):
            self.miss += 1
            return NotImplemented
        else:
            # non-Tensor types don't count as hit or miss
            return t

        if source is None:
            trace = False

        # Describe the tensor. NB: do NOT disable ambient modes, we may need
        # to query them when figuring out what to put in here
        t_desc = self.describer.describe_tensor(t, trace=trace)

        if trace:
            trace_structured(
                "describe_source",
                metadata_fn=lambda: {
                    "describer_id": self.describer.id,
                    "id": t_desc.id,
                    "source": source.name(),
                },
            )

        # Do the meta-fication. Here, we disable all the ambient modes, to
        # better simulate what it would be like to re-fakeify from a fresh
        # process
        with contextlib.ExitStack() as exit_stack:
            exit_stack.enter_context(torch._dispatch.python.suspend_functionalization())
            st = peek_interpreter_stack()
            if st is not None:
                exit_stack.enter_context(
                    torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack()
                )

            r = self.meta_tensor(
                t_desc,
                shape_env=shape_env,
                callback=callback,
                source=source,
                symbolic_context=symbolic_context,
            )

        if type(t) is torch.nn.Parameter:
            # NB: Cannot directly use Parameter constructor
            # because that would force a detach, not desirable
            r._is_param = True

        # TODO: return the description for later
        return r


import torch._prims_common as utils
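
# Usage sketch for MetaConverter (an illustrative example, not an API contract;
# in practice the converter is normally driven by FakeTensorMode/Dynamo
# internals rather than called by hand):
#
#     converter = MetaConverter()
#     x = torch.randn(4, 4)
#     y = x[1:]                       # a view sharing x's storage
#     mx = converter(x)               # meta tensor mirroring x
#     my = converter(y)               # meta view reconstructed off mx
#     assert mx.device.type == "meta"
#     assert my._is_view() and my._base is mx
#     assert converter.successful()   # both conversions hit, no misses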