1# mypy: allow-untyped-decorators 2# Copyright (c) Facebook, Inc. and its affiliates. 3# All rights reserved. 4# 5# This source code is licensed under the BSD-style license found in the 6# LICENSE file in the root directory of this source tree. 7 8from __future__ import annotations 9 10import functools 11import inspect 12import logging 13import operator 14import traceback 15import typing 16import typing_extensions 17import warnings 18import weakref 19from collections import defaultdict 20from contextlib import contextmanager, ExitStack, nullcontext 21from dataclasses import dataclass 22from typing import ( 23 Any, 24 Callable, 25 Dict, 26 Generator, 27 List, 28 Mapping, 29 Optional, 30 overload, 31 Protocol, 32 Sequence, 33 Tuple, 34 Type, 35 TYPE_CHECKING, 36 TypeVar, 37 Union, 38) 39from typing_extensions import Concatenate, ParamSpec, Self 40from weakref import WeakKeyDictionary 41 42import torch 43import torch._ops 44import torch.fx as fx 45import torch.fx.traceback as fx_traceback 46import torch.utils._pytree as pytree 47from torch import SymBool, SymInt, Tensor 48from torch._dispatch.python import enable_python_dispatcher 49from torch._library.fake_class_registry import FakeScriptObject 50from torch._subclasses.fake_impls import fast_detach 51from torch._subclasses.fake_tensor import ( 52 FakeTensor, 53 FakeTensorMode, 54 is_fake, 55 unset_fake_temporarily, 56) 57from torch._subclasses.meta_utils import is_sparse_any 58from torch.fx import GraphModule, Proxy, Tracer 59from torch.fx.graph_module import _assign_attr 60from torch.fx.node import _side_effectful_need_to_be_preserved_pre_dispatch 61from torch.fx.passes.shape_prop import _extract_tensor_metadata 62from torch.nn import Module 63from torch.overrides import TorchFunctionMode 64from torch.utils._python_dispatch import ( 65 _disable_infra_mode, 66 _push_mode, 67 _unset_infra_mode, 68 TorchDispatchMode, 69) 70from torch.utils._stats import count 71from torch.utils._thunk import Thunk 72from torch.utils._traceback import CapturedTraceback 73from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary 74 75from ._backward_state import BackwardState 76from .sym_node import SymNode 77 78 79if TYPE_CHECKING: 80 import types 81 from collections.abc import MutableMapping 82 83 import sympy 84 85 from torch._ops import OpOverload 86 from torch.fx._symbolic_trace import PHBase 87 from torch.types import IntLikeType 88 89__all__ = [ 90 "PythonKeyTracer", 91 "dispatch_trace", 92 "make_fx", 93 "DecompositionInterpreter", 94 "py_sym_types", 95 "get_innermost_proxy_mode", 96 "get_proxy_mode", 97 "handle_sym_dispatch", 98 "maybe_enable_thunkify", 99 "maybe_disable_thunkify", 100] 101 102_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"] 103 104_AnyScriptObject = (torch.ScriptObject, FakeScriptObject) 105_AnyScriptObjectType = Union[torch.ScriptObject, FakeScriptObject] 106 107aten = torch.ops.aten 108prim = torch.ops.prim 109 110log = logging.getLogger(__name__) 111not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") 112 113CURRENT_DECOMPOSITION_TABLE: Mapping[OpOverload, Callable] = {} 114 115CONSTANT_NUMEL_LIMIT = 1 116 117T = TypeVar("T") 118U = TypeVar("U") 119_P = ParamSpec("_P") 120R = TypeVar("R") 121 122null_ctx_type = type(nullcontext) 123# We currently convert all SymInt to proxies before we use them. 124# This could plausibly be handled at the Dynamo level. 
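# The registration below makes torch.Size a pytree container, so tree_map /
# tree_flatten recurse into sizes (including SymInt elements) instead of
# treating a Size as a leaf. A rough sketch of the effect (illustrative only):
#
#   leaves, spec = pytree.tree_flatten(torch.Size([2, 3]))
#   # leaves == [2, 3]; note the registered unflatten_fn below rebuilds a
#   # plain tuple, not a torch.Size:
#   pytree.tree_unflatten(leaves, spec)  # -> (2, 3)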
125pytree.register_pytree_node( 126 torch.Size, 127 lambda xs: (list(xs), None), 128 lambda xs, _: tuple(xs), 129 flatten_with_keys_fn=lambda xs: ( 130 [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)], 131 None, 132 ), 133) 134 135 136def fake_signature(fn: Callable[_P, R], nargs: int) -> Callable[_P, R]: 137 """FX gets confused by varargs, de-confuse it""" 138 argnames = ",".join(f"arg{i}" for i in range(nargs)) 139 return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) 140 141 142@contextmanager 143def decompose( 144 decomposition_table: Optional[Mapping[OpOverload, Callable]] 145) -> Generator[Mapping[OpOverload, Callable], None, None]: 146 global CURRENT_DECOMPOSITION_TABLE 147 old_decomposition_table = CURRENT_DECOMPOSITION_TABLE 148 CURRENT_DECOMPOSITION_TABLE = decomposition_table or {} 149 try: 150 yield CURRENT_DECOMPOSITION_TABLE 151 finally: 152 CURRENT_DECOMPOSITION_TABLE = old_decomposition_table 153 154 155# ensure we cannot collide with other properties 156proxy_slot = object() 157 158 159class _NoDefault: 160 pass 161 162 163no_default = _NoDefault() 164 165from torch.types import py_sym_types, PySymType 166 167 168class _HasMeta(Protocol): 169 meta: Dict[str, PySymType] 170 171 172def is_sym_node(node: _HasMeta) -> bool: 173 assert hasattr(node, "meta"), "All nodes traced with proxy_tensor should have meta" 174 return "val" in node.meta and isinstance(node.meta["val"], py_sym_types) 175 176 177@overload 178def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: 179 ... 180 181 182@overload 183def set_proxy_slot( 184 obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy 185) -> None: 186 ... 187 188 189@overload 190def set_proxy_slot( 191 obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType 192) -> None: 193 ... 194 195 196def set_proxy_slot( 197 obj: Union[PySymType, _AnyScriptObjectType, Tensor], 198 tracer: _ProxyTracer, 199 proxy: object, 200) -> None: 201 log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy) 202 if isinstance(obj, Tensor): 203 # We DO want to clobber proxies whenever we run an inplace operation 204 # on a tensor, and it affects the metadata on the proxy. 205 assert isinstance(proxy, _ProxyTensor) 206 tracer.tensor_tracker[obj] = proxy 207 elif isinstance(obj, (_AnyScriptObject)): 208 # We DO want to clobber proxies, with a similar rationale as for tensors. 209 assert isinstance(proxy, Proxy) 210 tracer.script_object_tracker[obj] = proxy 211 else: 212 # NB: Never clobber pre-existing proxy. Although the proxies 213 # are in principle equivalent, when we do graph partitioning 214 # we need there not to be spurious dependencies on tangent inputs. 215 # This works because primals get their SymInts set first, and 216 # THEN later we allocate tangent inputs. Make sure if a SymInt 217 # is derivable from a primal that we use that. 218 assert isinstance(obj, py_sym_types), type(obj) 219 if obj not in tracer.symnode_tracker: 220 tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy) 221 222 # WAR: python test/dynamo/test_subclasses.py 223 # TestNestedTensor.test_basic_autograd 224 # 225 # AOTAutograd doesn't pass the "outer sizes" as an actual argument 226 # to make_fx, but it is made use of internally in AOTAutograd's 227 # call to tensor unflatten. Because the outer sizes isn't passed 228 # as an argument, it is therefore untracked. However, it turns 229 # out you luck out, because *Dynamo* will manually add the outer 230 # sizes as an argument so you can fix up the proxy'ness. 
231 # 232 # This is probably fixed in 233 # https://github.com/pytorch/pytorch/pull/125941/ 234 import sympy 235 236 if isinstance(obj.node.expr, sympy.Symbol): 237 tracer.sympy_expr_tracker[obj.node.expr] = proxy 238 239 240def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool: 241 assert isinstance(obj, (Tensor, SymNode)), type(obj) 242 return bool(get_proxy_slot(obj, tracer, False, lambda _: True)) 243 244 245_PySymProxyType = Thunk[Proxy] 246 247 248@overload 249def get_proxy_slot( 250 obj: Tensor, 251 tracer: _ProxyTracer, 252) -> _ProxyTensor: 253 ... 254 255 256@overload 257def get_proxy_slot( 258 obj: Tensor, 259 tracer: _ProxyTracer, 260 default: U, 261) -> Union[_ProxyTensor, U]: 262 ... 263 264 265@overload 266def get_proxy_slot( 267 obj: Tensor, 268 tracer: _ProxyTracer, 269 default: U, 270 transform: Callable[[_ProxyTensor], R], 271) -> Union[R, U]: 272 ... 273 274 275@overload 276def get_proxy_slot( 277 obj: _AnyScriptObjectType, 278 tracer: _ProxyTracer, 279) -> Proxy: 280 ... 281 282 283@overload 284def get_proxy_slot( 285 obj: _AnyScriptObjectType, 286 tracer: _ProxyTracer, 287 default: U, 288) -> Union[Proxy, U]: 289 ... 290 291 292@overload 293def get_proxy_slot( 294 obj: _AnyScriptObjectType, 295 tracer: _ProxyTracer, 296 default: U, 297 transform: Callable[[Proxy], R], 298) -> Union[R, U]: 299 ... 300 301 302@overload 303def get_proxy_slot( 304 obj: PySymType, 305 tracer: _ProxyTracer, 306) -> _PySymProxyType: 307 ... 308 309 310@overload 311def get_proxy_slot( 312 obj: PySymType, 313 tracer: _ProxyTracer, 314 default: T, 315) -> Union[T, _PySymProxyType]: 316 ... 317 318 319@overload 320def get_proxy_slot( 321 obj: PySymType, 322 tracer: _ProxyTracer, 323 default: U, 324 transform: Callable[[_PySymProxyType], R], 325) -> Union[R, U]: 326 ... 327 328 329# the default argument is what to return if the slot is not set. 330# the transform argument is handy if you need to extract a subfield from 331# the successfully looked up result (but NOT the default.) 
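# A few representative call shapes (illustrative; mirrors usages elsewhere in
# this file):
#
#   get_proxy_slot(t, tracer)                        # _ProxyTensor, or raise
#   get_proxy_slot(t, tracer, None)                  # _ProxyTensor or None
#   get_proxy_slot(t, tracer, t, lambda x: x.proxy)  # Proxy, or t if untracked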
332def get_proxy_slot( 333 obj: Union[Tensor, _AnyScriptObjectType, PySymType], 334 tracer: _ProxyTracer, 335 default: object = no_default, 336 transform: Callable = lambda x: x, 337) -> object: 338 tracker: Any 339 if isinstance(obj, Tensor): 340 tracker = tracer.tensor_tracker 341 elif isinstance(obj, _AnyScriptObject): 342 tracker = tracer.script_object_tracker 343 else: 344 assert isinstance(obj, py_sym_types), type(obj) 345 tracker = tracer.symnode_tracker 346 347 if obj not in tracker: 348 # Last ditch 349 if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker: 350 value = tracer.sympy_expr_tracker[obj.node.expr] 351 else: 352 if isinstance(default, _NoDefault): 353 raise RuntimeError( 354 f"{obj} ({id(obj)})is not tracked with proxy for {tracer}" 355 ) 356 return default 357 else: 358 value = tracker[obj] 359 res = transform(value) 360 return res 361 362 363def snapshot_fake(val: Tensor) -> Optional[Tensor]: 364 # val.detach() will also eventually call fast_detach(), 365 # but this saves us a full trip into __torch_dispatch__ 366 # (snapshot_fake is called a lot) 367 if isinstance(val, FakeTensor): 368 return fast_detach(val.fake_mode, val) 369 else: 370 return val.detach() 371 372 373_ExtractValType = Optional[ 374 Union[ 375 PySymType, 376 _AnyScriptObjectType, 377 BackwardState, 378 List["_ExtractValType"], 379 Tuple["_ExtractValType", ...], 380 Dict[str, "_ExtractValType"], 381 Tensor, 382 int, 383 float, 384 bool, 385 ] 386] 387 388 389def extract_val(val: _ExtractValType) -> _ExtractValType: 390 if is_fake(val): 391 return snapshot_fake(val) 392 elif isinstance(val, py_sym_types): 393 return val 394 elif isinstance(val, _AnyScriptObject): 395 return val 396 elif isinstance(val, BackwardState): 397 return val 398 elif isinstance(val, (list, tuple)): 399 return val.__class__([extract_val(x) for x in val]) 400 elif isinstance(val, dict): 401 return {k: extract_val(v) for k, v in val.items()} 402 elif isinstance(val, Tensor): 403 if not val.is_sparse: 404 # NB: Kinda hacky, but we should try to get val as the metadata 405 # everywhere 406 # TODO: This doesn't properly track storages. A more robust 407 # approach would be to maintain a per-trace FakeTensorMode and 408 # from_real_tensor to create fake values (don't forget to 409 # snapshot_fake) 410 fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) 411 with fake_tensor_mode: 412 return torch.empty_strided( 413 val.shape, val.stride(), device=val.device, dtype=val.dtype 414 ) 415 else: 416 return None 417 elif isinstance(val, (int, float, bool)): 418 return val 419 elif val is None: 420 return None 421 422 typing_extensions.assert_never(val) 423 424 425@contextmanager 426def _enable_thunkify( 427 tracer: _ProxyTracer, *, enable: bool = True 428) -> Generator[None, None, None]: 429 """ 430 Enable thunkification inside the context manager. Thunkification prevents 431 SymNode computation from directly being traced into an FX graph; instead, 432 the compute is only added to the graph if it is actually used. This helps 433 us track SymNode compute when it is computed (since we need /something/ 434 to put in the tracker) even if it is unlikely to be used. 435 """ 436 old = tracer.enable_thunkify 437 tracer.enable_thunkify = enable 438 try: 439 yield 440 finally: 441 tracer.enable_thunkify = old 442 443 444@contextmanager 445def maybe_disable_thunkify() -> Generator[None, None, None]: 446 """Within a context, disable thunkification. See :func:`maybe_enable_thunkify` 447 for more details. 
This is helpful if you have a wrapper function which 448 you want to enable thunkification on, but in some segment on the inside (say, 449 the original user function), you want to disable thunkification as you know 450 it is not needed there. 451 """ 452 proxy_mode = get_proxy_mode() 453 if proxy_mode is not None: 454 with _enable_thunkify(proxy_mode.tracer, enable=False): 455 yield 456 else: 457 yield 458 459 460@contextmanager 461def maybe_enable_thunkify() -> Generator[None, None, None]: 462 """Within this context manager, if you are doing make_fx tracing, we will thunkify 463 all SymNode compute and avoid tracing it into the graph unless it is actually needed. 464 You should prefer to avoid using this as much as possible, as lazy evaluation of 465 SymNode tracing can lead to long chains of thunks which will stack overflow 466 if you evaluate them. However, this is currently sometimes necessary as there 467 are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error 468 due to insufficient tracing of SymNode computation. 469 """ 470 proxy_mode = get_proxy_mode() 471 if proxy_mode is not None: 472 with _enable_thunkify(proxy_mode.tracer): 473 yield 474 else: 475 yield 476 477 478# Note [invariants for node meta 'val'] 479# What invariants do we have for the 'val' set on the FX node? It has accurate 480# metadata... but only for metadata that exists "below" all other subsystems 481# (most notably autograd, but also vmap, functorch transforms, etc). This means 482# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad, 483# grad_fn, _base (_base actually may be set due to recursive call to 484# ADInplaceOrView, but you shouldn't rely on it.) 485def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy: 486 proxy.node.meta["val"] = extract_val(val) 487 488 with _enable_thunkify(proxy.tracer): # type: ignore[arg-type] 489 # Best effort tensor_meta setting; prefer using val! 490 if is_fake(val): 491 proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val) 492 elif isinstance(val, Tensor) and not val.is_sparse: 493 proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val) 494 return proxy 495 496 497def thunkify( 498 tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs 499) -> Thunk[R]: 500 """ 501 Delays computation of f until it's called again 502 Also caches the result 503 """ 504 if tracer.enable_thunkify: 505 return Thunk(functools.partial(f, *args, **kwargs)) 506 else: 507 r = f(*args, **kwargs) 508 return Thunk(lambda: r) 509 510 511def track_tensor( 512 tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer 513) -> None: 514 def try_set_proxy_slot( 515 outer_s: IntLikeType, 516 proxy_callable: Callable[Concatenate[PySymType, _P], Proxy], 517 *args: _P.args, 518 **kwargs: _P.kwargs, 519 ) -> None: 520 assert callable(proxy_callable) 521 if isinstance(outer_s, SymInt): 522 with _enable_thunkify(tracer): 523 set_proxy_slot( 524 outer_s, 525 tracer, 526 thunkify(tracer, proxy_callable, outer_s, *args, **kwargs), 527 ) 528 529 # The basic idea is that we need to associate each tensor/SymInt 530 # with a Proxy. How do we setup this association? We just store 531 # the proxy on the proxy slot of the object, keyed on the tracer 532 # (so that if we have multiple tracers at the same time, they 533 # don't clobber each other.) 
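# For a tensor with a symbolic shape like (s0, s1), the loops below register
# thunks that, only if forced, add nodes along the lines of (illustrative
# readback; "proxy" stands for this tensor's proxy):
#
#   sym_size_int = torch.ops.aten.sym_size.int(proxy, 0)    # backs s0
#   sym_size_int_1 = torch.ops.aten.sym_size.int(proxy, 1)  # backs s1
#
# and similarly sym_stride.int / sym_numel / sym_storage_offset, so later
# SymInt uses can be traced back to this tensor's proxy.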
534 for i, s in enumerate(tensor.shape): 535 try_set_proxy_slot( 536 s, 537 lambda x, i: set_meta( 538 tracer.create_proxy( 539 "call_function", torch.ops.aten.sym_size.int, (proxy, i), {} 540 ), 541 x, 542 ), 543 i, 544 ) 545 546 if not is_sparse_any(tensor): 547 for i, s in enumerate(tensor.stride()): 548 try_set_proxy_slot( 549 s, 550 lambda x, i: set_meta( 551 tracer.create_proxy( 552 "call_function", torch.ops.aten.sym_stride.int, (proxy, i), {} 553 ), 554 x, 555 ), 556 i, 557 ) 558 559 try_set_proxy_slot( 560 tensor.numel(), 561 lambda x: set_meta( 562 tracer.create_proxy( 563 "call_function", torch.ops.aten.sym_numel.default, (proxy,), {} 564 ), 565 x, 566 ), 567 ) 568 if not is_sparse_any(tensor): 569 try_set_proxy_slot( 570 tensor.storage_offset(), 571 lambda x: set_meta( 572 tracer.create_proxy( 573 "call_function", 574 torch.ops.aten.sym_storage_offset.default, 575 (proxy,), 576 {}, 577 ), 578 x, 579 ), 580 ) 581 set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant)) 582 583 584_NestedProxys = Union[ 585 Proxy, Sequence["_NestedProxys"], Mapping[object, "_NestedProxys"] 586] 587_NestedTensors = Union[ 588 Tensor, Sequence["_NestedTensors"], Mapping[object, "_NestedTensors"] 589] 590 591 592def track_tensor_tree( 593 inner_res: T, 594 proxy_res: _NestedProxys, 595 *, 596 constant: Optional[_NestedTensors], 597 tracer: _ProxyTracer, 598) -> T: 599 # NB: We call set_unbacked_bindings only on the *topmost* call to 600 # track_tensor_tree, not recursive calls. This is because there must 601 # be only ONE unbacked_binding proxy call, and it should be the one 602 # where all of the unbacked SymInts actually first come into existence. 603 # If you call this again on the inner proxies for the tuple projections, 604 # you will have multiple unbacked_bindings for the same symbol, but 605 # they're not going to show up anywhere. 606 # 607 # I was briefly deceived into setting unbacked bindings recursively when 608 # working on https://github.com/pytorch/pytorch/pull/133585 because I 609 # observed that some extra unbacked bindings were needed to handle some 610 # higher order operator code. But actually it looks like this was 611 # just an unrelated bug that needed to be fixed separately. 612 _set_unbacked_bindings(inner_res, proxy_res) 613 614 def wrap_with_proxy( 615 e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors] 616 ) -> None: 617 if isinstance(e, Tensor): 618 assert isinstance(proxy, Proxy) 619 assert constant is None or isinstance(constant, Tensor) 620 track_tensor(e, proxy, tracer=tracer, constant=constant) 621 set_meta(proxy, e) 622 elif isinstance(e, py_sym_types): 623 assert isinstance(proxy, Proxy) 624 # NB: eagerly set meta here, so that the numbering is in order 625 set_meta(proxy, e) 626 set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy)) 627 elif isinstance(e, _AnyScriptObject): 628 assert isinstance(proxy, Proxy) 629 set_proxy_slot(e, tracer, proxy) 630 set_meta(proxy, e) 631 elif isinstance(e, (tuple, list)): 632 # example use case: allreduce_ returns ([tensor], work) 633 if isinstance(proxy, fx.Proxy): 634 set_meta(proxy, e) 635 636 def get_constant( 637 c: Optional[_NestedTensors], idx: int 638 ) -> Optional[_NestedTensors]: 639 if c is None: 640 return None 641 else: 642 assert isinstance(c, (list, tuple)) 643 return c[idx] 644 645 for idx, ee in enumerate(e): 646 # Use an indexer here - if proxy is a List then it will unwrap 647 # it. If it's a Proxy then it will proxy the getelem. 
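# (Concretely: indexing an fx.Proxy records an operator.getitem call in the
# graph, whereas indexing a plain list/tuple just selects the idx-th proxy.)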
648 wrap_with_proxy(ee, proxy[idx], get_constant(constant, idx)) # type: ignore[index] 649 650 elif isinstance(e, dict): 651 # example use case: triton_kernel_wrapper takes arguments as kwargs 652 653 # In theory we could support const-prop when proxy-tensor-tracing 654 # operators that returns dicts of tensors, but we have no use case 655 # for it today (since the only op we currently trace that can 656 # return a dict is triton_kernel_wrapper_functional/mutation, 657 # which does not participate in const-prop) 658 assert constant is None 659 660 if isinstance(proxy, fx.Proxy): 661 set_meta(proxy, e) 662 663 for key, val in e.items(): 664 wrap_with_proxy(val, proxy[key], None) # type: ignore[index] 665 666 elif isinstance(e, BackwardState): 667 assert isinstance(proxy, Proxy) 668 set_meta(proxy, e) 669 e.proxy = proxy 670 else: 671 # intentionally pass on primitives 672 pass 673 674 wrap_with_proxy(inner_res, proxy_res, constant) 675 676 return inner_res 677 678 679@dataclass 680class _ProxyTensor: 681 proxy: Proxy 682 constant: Optional[Tensor] 683 684 685def fetch_sym_proxy( 686 tracer: _ProxyTracer, 687) -> Callable[[PySymType], Union[bool, int, float, Proxy]]: 688 def inner(e: PySymType) -> Union[int, bool, float, Proxy]: 689 n = e.node 690 if n.constant is not None: 691 return n.constant 692 if e.node.expr.is_number: 693 if isinstance(e, SymBool): 694 return bool(e.node.expr) 695 elif isinstance(e, SymInt): 696 return int(e.node.expr) 697 return float(e.node.expr) 698 else: 699 assert isinstance(e, py_sym_types) 700 # NB: we REQUIRE all symints to be tracked 701 return get_proxy_slot(e, tracer).force() 702 703 return inner 704 705 706@overload 707def fetch_object_proxy(tracer: _ProxyTracer, t: Tensor) -> Union[_ProxyTensor, Tensor]: 708 ... 709 710 711@overload 712def fetch_object_proxy( 713 tracer: _ProxyTracer, t: _AnyScriptObjectType 714) -> Union[Proxy, _AnyScriptObjectType]: 715 ... 716 717 718@overload 719def fetch_object_proxy( 720 tracer: _ProxyTracer, t: PySymType 721) -> Union[_PySymProxyType, PySymType]: 722 ... 
723 724 725def fetch_object_proxy( 726 tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType] 727) -> object: 728 return get_proxy_slot(t, tracer, t) 729 730 731HANDLED_TYPES = (Tensor, torch.nn.Parameter, FakeTensor) 732 733 734def _maybe_record_pointwise_barrier( 735 func: object, proxy_mode: ProxyTorchDispatchMode 736) -> None: 737 """ 738 Records pointwise operators in user program (non decomposed) that were output in fp16/bf16 739 """ 740 if proxy_mode.decomp_layers or not proxy_mode.emulate_precision_casts: 741 return 742 743 if ( 744 not isinstance(func, torch._ops.OpOverload) 745 or torch.Tag.pointwise not in func.tags 746 ): 747 return 748 749 last_node = next(iter(reversed(proxy_mode.tracer.graph.nodes))) 750 t = last_node.meta.get("val") 751 if not isinstance(t, torch.Tensor) or t.dtype not in ( 752 torch.bfloat16, 753 torch.float16, 754 ): 755 return 756 757 last_node.meta["low_precision_pointwise_barrier"] = True 758 759 760def proxy_call( 761 proxy_mode: ProxyTorchDispatchMode, 762 func: OpOverload, 763 pre_dispatch: bool, 764 args: Tuple[object, ...], 765 kwargs: Dict[str, object], 766) -> object: 767 unrecognized_types: List[Type] = [] 768 flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs)) 769 770 def can_handle_tensor(x: Tensor) -> bool: 771 r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer) 772 if proxy_mode._allow_fake_constant: 773 r = r or type(x) in (torch._subclasses.FakeTensor,) 774 if not r: 775 unrecognized_types.append(type(x)) 776 return r 777 778 # If there are any tensor subclasses, we need to handle those tensor subclasses first 779 # TODO: we could use types to test this 780 if not all(can_handle_tensor(x) for x in flat_args_kwargs if isinstance(x, Tensor)): 781 not_implemented_log.debug( 782 "ProxyTensorMode tensors without proxy had unrecognized subclasses: %s", 783 unrecognized_types, 784 ) 785 return NotImplemented 786 787 r = maybe_handle_decomp(proxy_mode, func, args, kwargs) 788 if r is not NotImplemented: 789 _maybe_record_pointwise_barrier(func, proxy_mode) 790 return r 791 792 # For pre-autograd tracing, we do not want to run CompositeImplicit decomps. 793 if not pre_dispatch and func not in [ 794 torch.ops.aten.size.default, 795 torch.ops.aten.stride.default, 796 torch.ops.aten.storage_offset.default, 797 ]: 798 with proxy_mode: 799 r = func.decompose(*args, **kwargs) 800 if r is not NotImplemented: 801 return r 802 803 tracer = proxy_mode.tracer 804 f_flat_args_kwargs = [ 805 ( 806 fetch_object_proxy(tracer, x) 807 if isinstance(x, (Tensor, _AnyScriptObject)) 808 else x 809 ) 810 for x in flat_args_kwargs 811 ] 812 813 # If there are SymInts, we also should not consider this constant. 814 # However, fake tensor handling of SymInts is sufficiently broken that 815 # I couldn't write a test for this case 816 all_constant = ( 817 not any( 818 t.constant is None 819 for t in f_flat_args_kwargs 820 if isinstance(t, _ProxyTensor) 821 ) 822 # TODO: maybe constant SymInts should also be allowed? 
Not sure if 823 # this can happen 824 and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs) 825 ) 826 827 if torch.Tag.data_dependent_output in func.tags: 828 # Check if all of the Tensor inputs are constants 829 if all_constant: 830 const_flat_args_kwargs = [ 831 t.constant if isinstance(t, _ProxyTensor) else t 832 for t in f_flat_args_kwargs 833 ] 834 const_args, const_kwargs = pytree.tree_unflatten( 835 const_flat_args_kwargs, spec 836 ) 837 with unset_fake_temporarily(): 838 return func(*const_args, **const_kwargs) 839 # If any of the Tensor inputs are "real" (not FakeTensor), we may 840 # incorrectly burn in constants by allowing this access. Raise 841 # an error in this case 842 if proxy_mode._error_on_data_dependent_ops and pytree.tree_all_only( 843 Tensor, lambda t: not is_fake(t), (args, kwargs) 844 ): 845 raise RuntimeError( 846 f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! " 847 "It's likely that this is caused by data-dependent control flow or similar. " 848 "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' " 849 "in your make_fx call." 850 ) 851 852 proxy_flat_args_kwargs = [ 853 e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs 854 ] 855 proxy_flat_args_kwargs = [ 856 (fetch_sym_proxy(proxy_mode.tracer)(e) if isinstance(e, py_sym_types) else e) 857 for e in proxy_flat_args_kwargs 858 ] 859 proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec) 860 861 # When we trace through a torch.tensor invocation, you never actually 862 # see a torch.ops.aten.tensor call. Instead, the way this function is 863 # implemented internally is that we allocate a plain tensor (this is 864 # *guaranteed* to be a plain tensor, we disable all modes when doing 865 # so), and then call at::lift_fresh on it (to give modes a chance to do 866 # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed 867 # to be freshly allocated, so we want lift_fresh to be a no-op (directly 868 # returning the input argument). 869 # 870 # Here is the basic problem: when we trace this sequence of executions 871 # into an FX graph, what happens to this call sequence? Traditionally, 872 # tensor constants get interned as buffers on the FX GraphModule. But 873 # this is dangerous. Consider: 874 # 875 # x = torch.tensor(1) 876 # x.add_(2) 877 # 878 # Naively, this traces into: 879 # 880 # t = self._tensor_constant0 # initialized to torch.tensor(1) 881 # x = torch.ops.aten.lift_fresh(t) 882 # x.add_(2) 883 # 884 # If lift_fresh returns t directly, the subsequent add_ call will 885 # modify the tensor constant. Really, the problem is we've violated 886 # the invariant the argument to lift is fresh. So what we should 887 # preserve the invariant by replacing lift_fresh with lift_fresh_copy: 888 # 889 # t = self._tensor_constant0 # initialized to torch.tensor(1) 890 # x = torch.ops.aten.lift_fresh_copy(t) 891 # x.add_(2) 892 # 893 # This is what the overload modification does. 
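# Readback of the safe form, illustrative:
#
#   _tensor_constant0 = self._tensor_constant0
#   lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0)
#   add_ = torch.ops.aten.add_.Tensor(lift_fresh_copy, 2)
#
# Note also that further down, the lift_fresh_copy branch clones args[0] so the
# constant value stays available for later item()-style queries.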
894 if func is torch.ops.aten.lift_fresh.default: 895 func = torch.ops.aten.lift_fresh_copy.default 896 897 proxy_out = proxy_mode.tracer.create_proxy( 898 "call_function", 899 func, 900 proxy_args, 901 proxy_kwargs, 902 name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__), 903 ) 904 905 with _enable_thunkify(proxy_mode.tracer): 906 out = func(*args, **kwargs) 907 908 # In some circumstances, we will be tracing in a situation where a tensor 909 # is *statically* known to be a constant (currently, this only happens if 910 # you run torch.tensor; deterministic factory functions like torch.arange 911 # don't get this treatment). When the tensor in question is small, it's 912 # helpful to do constant propagation in case we call item() (in which 913 # case we can return the constant value that is known, rather than give 914 # an error.) The logic here tests if constant propagation is possible 915 # (because all of the inputs are constant). If so, we disable fake tensor 916 # mode (if it is on) and do true compute on the constant. 917 # 918 # It's worth highlighting that we're making a policy decision here. 919 # There is a potential that the tensor is actually quite large, and we 920 # don't actually want to run the compute. The tensor being quite large 921 # is one of the reasons why factory functions don't get this treatment 922 # (since they can be quite large; if a parameter is initialized to a 923 # constant value it will be!) Similarly, there is also a potential 924 # to run an operator that blows up the size of a small tensor; we don't 925 # protect against this case, but we could force, e.g., only single 926 # element constant computation by testing the numel of the result before 927 # propagating const-ness. Similarly, we don't require the constant to 928 # live on CPU, but we could.
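# As a concrete (illustrative) consequence of this policy, something like
#
#   def f(x):
#       return x + torch.tensor(2).item()
#
#   make_fx(f)(torch.randn(3))
#
# can trace even under fake tensor tracing: the torch.tensor(2) constant
# (numel <= CONSTANT_NUMEL_LIMIT) is carried along, so the .item() call is
# answered from the constant rather than raising a data-dependent error.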
929 any_constant = any( 930 t.constant is not None 931 for t in f_flat_args_kwargs 932 if isinstance(t, _ProxyTensor) 933 ) 934 935 constant = None 936 937 def tensor_numel_in_limit(t: Tensor) -> bool: 938 return t.numel() <= CONSTANT_NUMEL_LIMIT 939 940 # If this is a lift, the input tensor is guaranteed to be a 941 # constant, so we keep a copy of the original argument along so 942 # we can query it if we're asked to item() it at some later point 943 if ( 944 func is torch.ops.aten.lift_fresh_copy.default 945 and out.numel() <= CONSTANT_NUMEL_LIMIT 946 ): 947 with unset_fake_temporarily(): 948 assert isinstance(args[0], (Proxy, Tensor)), type(args[0]) 949 constant = args[0].clone() 950 elif ( 951 torch.Tag.nondeterministic_seeded not in func.tags 952 and all_constant 953 and any_constant 954 and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out) 955 ): 956 # NB: do NOT include factories as constants 957 with unset_fake_temporarily(): 958 const_flat_args_kwargs = [ 959 t.constant if isinstance(t, _ProxyTensor) else t 960 for t in f_flat_args_kwargs 961 ] 962 const_args, const_kwargs = pytree.tree_unflatten( 963 const_flat_args_kwargs, spec 964 ) 965 constant = func(*const_args, **const_kwargs) 966 else: 967 constant = None 968 969 track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer) 970 _maybe_record_pointwise_barrier(func, proxy_mode) 971 return out 972 973 974class _SymNodeDict: 975 """ 976 Wrapper around a dictionary that will hash SymInts with their nodes 977 """ 978 979 def __init__(self) -> None: 980 self.sym_node_dict: Dict[PySymType, _PySymProxyType] = {} 981 982 def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None: 983 self.sym_node_dict[key.node] = value 984 985 def __getitem__(self, key: PySymType) -> _PySymProxyType: 986 return self.sym_node_dict[key.node] 987 988 def __contains__(self, key: PySymType) -> bool: 989 return key.node in self.sym_node_dict 990 991 def get( 992 self, key: PySymType, default: Optional[_PySymProxyType] = None 993 ) -> _PySymProxyType: 994 # dict.get()'s annotation doesn't accept `None` when the value type 995 # isn't Optional. 996 return self.sym_node_dict.get(key.node, default) # type: ignore[arg-type] 997 998 def __iter__(self) -> Any: 999 raise NotImplementedError 1000 1001 def __len__(self) -> int: 1002 return len(self.sym_node_dict) 1003 1004 1005class PythonKeyTracer(Tracer): 1006 script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] 1007 symnode_tracker: _SymNodeDict 1008 sympy_expr_tracker: Dict[sympy.Symbol, object] 1009 tensor_tracker: MutableMapping[Tensor, _ProxyTensor] 1010 torch_fn_counts: Dict[OpOverload, int] 1011 enable_thunkify: bool = False 1012 1013 def __init__(self) -> None: 1014 super().__init__(autowrap_modules=()) # type: ignore[arg-type] 1015 self.tensor_tracker = WeakTensorKeyDictionary() 1016 self.symnode_tracker = _SymNodeDict() 1017 self.script_object_tracker = WeakIdKeyDictionary( 1018 dict=None, ref_type=_WeakHashRef 1019 ) 1020 self.sympy_expr_tracker = dict() 1021 1022 # Stores the torch function that was called during tracing 1023 self.torch_fn_metadata = None 1024 # Stores the counts for every torch function called. This is to help 1025 # distinguish between different calls to the same torch function. 1026 self.torch_fn_counts = {} 1027 self.enable_thunkify = False 1028 1029 # In general, we don't want to make modules leaves. 
In principle, users of 1030 # this tracer might want to override this in order to turn a couple specific 1031 # modules into leaves in the traced graph. 1032 def call_module( 1033 self, 1034 m: Module, 1035 forward: Callable[..., Any], 1036 args: Tuple[Any, ...], 1037 kwargs: Dict[str, Any], 1038 ) -> Any: 1039 return forward(*args, **kwargs) 1040 1041 # We don't want to turn getattr calls into proxies. So we just return the actual value. 1042 def getattr( 1043 self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy] 1044 ) -> object: 1045 return attr_val 1046 1047 def create_arg(self, a: object) -> fx.node.Node: 1048 if isinstance(a, torch.nn.Parameter): 1049 for n, p in self.root.named_parameters(): 1050 if a is p: 1051 return self.create_node("get_attr", n, (), {}) 1052 1053 qualname = self.get_fresh_qualname("_param_constant") 1054 setattr(self.root, qualname, a) 1055 1056 return self.create_node("get_attr", qualname, (), {}) 1057 elif isinstance(a, py_sym_types): 1058 assert a.node.constant is not None 1059 return a.node.constant 1060 return super().create_arg(a) # type: ignore[return-value] 1061 1062 @overload 1063 def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: 1064 ... 1065 1066 @overload 1067 def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: 1068 ... 1069 1070 @overload 1071 def unwrap_proxy( 1072 self, e: _AnyScriptObjectType 1073 ) -> Union[Proxy, _AnyScriptObjectType]: 1074 ... 1075 1076 def unwrap_proxy(self, e: T) -> object: 1077 if isinstance(e, Tensor): 1078 return get_proxy_slot(e, self, e, lambda x: x.proxy) 1079 elif isinstance(e, py_sym_types): 1080 return get_proxy_slot(e, self, e, lambda e: e.force()) 1081 elif isinstance(e, _AnyScriptObject): 1082 return get_proxy_slot(e, self, e) 1083 else: 1084 return e 1085 1086 1087@contextmanager 1088def _temp_remove_pre_dispatch_torch_function_mode() -> Generator[None, None, None]: 1089 from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode 1090 1091 temp_elements = [] 1092 pre_dispatch_mode = None 1093 1094 while _len_torch_function_stack() > 0: 1095 mode = _pop_mode() 1096 if isinstance(mode, PreDispatchTorchFunctionMode): 1097 pre_dispatch_mode = mode 1098 break 1099 else: 1100 temp_elements.append(mode) 1101 1102 for mode in reversed(temp_elements): 1103 _push_mode(mode) 1104 1105 try: 1106 yield 1107 1108 finally: 1109 if pre_dispatch_mode is not None: 1110 count = len(temp_elements) 1111 while count > 0: 1112 mode = _pop_mode() 1113 count -= 1 1114 1115 temp_elements.append(pre_dispatch_mode) 1116 1117 for mode in reversed(temp_elements): 1118 _push_mode(mode) 1119 1120 1121@torch._disable_dynamo 1122def dispatch_trace( 1123 root: Union[Module, Callable], 1124 tracer: Tracer, 1125 concrete_args: Optional[Tuple[Any, ...]] = None, 1126) -> GraphModule: 1127 graph = tracer.trace(root, concrete_args) # type: ignore[arg-type] 1128 1129 # NB: be careful not to DCE .item() calls 1130 def impure_pred(n: fx.Node) -> bool: 1131 from .symbolic_shapes import is_accessor_node 1132 1133 # Always defer to the built-in notion of impure 1134 if n.is_impure(): 1135 return True 1136 1137 # Accessors always OK to DCE 1138 if is_accessor_node(n): 1139 return False 1140 1141 # If the operator in question takes SymInt args to SymInt output, 1142 # we assume it's pure and OK to DCE 1143 if ( 1144 isinstance(n.meta.get("val"), py_sym_types) 1145 and 1146 # NB: constant args ok 1147 all( 1148 isinstance(a.meta.get("val"), py_sym_types) 1149 for a in n.args 1150 if isinstance(a, 
fx.Node) 1151 ) 1152 ): 1153 return False 1154 1155 # No idea, just assume it's not OK 1156 return True 1157 1158 graph.eliminate_dead_code(impure_pred) 1159 from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints 1160 1161 dedupe_symints(graph) 1162 name = root.__class__.__name__ if isinstance(root, Module) else root.__name__ 1163 return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name) 1164 1165 1166def wrap_key( 1167 f: Callable[_P, R], tensors: _P.args, tracer: _ProxyTracer, pre_dispatch: bool 1168) -> Callable[_P, R]: 1169 flat_tensors, tensors_spec = pytree.tree_flatten(tensors) 1170 1171 @functools.wraps(f) 1172 def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R: 1173 flat_proxies, proxies_spec = pytree.tree_flatten(proxies) 1174 assert len(flat_proxies) == len(flat_tensors) 1175 with disable_proxy_modes_tracing() as m: 1176 assert isinstance(m, ProxyTorchDispatchMode) 1177 track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) 1178 1179 def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]: 1180 return get_proxy_slot(t, tracer, t, lambda x: x.proxy) 1181 1182 out = f(*tensors) 1183 out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out) 1184 out = pytree.tree_map_only( 1185 _AnyScriptObject, lambda t: get_proxy_slot(t, tracer, t, lambda x: x), out 1186 ) 1187 1188 def get_sym_proxy_slot(t: PySymType) -> Proxy: 1189 return get_proxy_slot(t, tracer).force() 1190 1191 out = pytree.tree_map_only(py_sym_types, get_sym_proxy_slot, out) 1192 return out 1193 1194 return wrapped 1195 1196 1197# TODO: Make downstream users of this work with OperatorBase 1198ORIGINAL_ATEN: Optional[object] = None 1199 1200 1201@contextmanager 1202def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]: 1203 global ORIGINAL_ATEN 1204 if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta(): 1205 ORIGINAL_ATEN = func 1206 fx_traceback.current_meta["original_aten"] = func 1207 try: 1208 yield 1209 finally: 1210 ORIGINAL_ATEN = None 1211 fx_traceback.current_meta["original_aten"] = None 1212 else: 1213 yield 1214 1215 1216class TorchFunctionMetadataMode(TorchFunctionMode): 1217 def __init__(self, tracer: _ProxyTracer) -> None: 1218 self.tracer = tracer 1219 1220 def __torch_function__( 1221 self, 1222 func: OpOverload, 1223 types: Tuple[torch._C._TensorMeta, ...], 1224 args: Tuple[object, ...] = (), 1225 kwargs: Optional[Dict[str, object]] = None, 1226 ) -> object: 1227 kwargs = kwargs or {} 1228 self.tracer.torch_fn_metadata = func 1229 self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1 1230 return func(*args, **kwargs) 1231 1232 1233# This mode is **only** used for pre_dispatch tracing. 1234# In particular, we need to make sure that autograd/autocast API's 1235# that do not desugar into dispatcher operators stay in the graph. 1236class PreDispatchTorchFunctionMode(TorchFunctionMode): 1237 def __init__(self, tracer: _ProxyTracer) -> None: 1238 self.tracer = tracer 1239 1240 def __torch_function__( 1241 self, 1242 func: OpOverload, 1243 types: Tuple[torch._C._TensorMeta, ...], 1244 args: Tuple[object, ...] = (), 1245 kwargs: Optional[Dict[str, object]] = None, 1246 ) -> object: 1247 kwargs = kwargs or {} 1248 if func in _side_effectful_need_to_be_preserved_pre_dispatch: 1249 # It's for passing the export verifier which needs to verify the meta['val'] 1250 # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier, 1251 # instead of hardcoding it here. 
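# Illustrative effect (hypothetical example): when pre-dispatch tracing
#
#   def f(x):
#       with torch.no_grad():
#           return x + 1
#
# the enter/exit of no_grad shows up roughly as torch._C._set_grad_enabled(False)
# / torch._C._set_grad_enabled(True) call_function nodes in the graph, instead
# of actually flipping global grad mode while tracing (see the note below).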
1252 node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type] 1253 if func is torch._C._set_grad_enabled: 1254 node.meta["val"] = None 1255 return node 1256 # Don't actually run the function! We just want to trace the calls 1257 # into a graph. We don't actually want to change global autograd state. 1258 return func(*args, **kwargs) 1259 1260 1261 class ProxyTorchDispatchMode(TorchDispatchMode): 1262 # Ensure this is read-only; this exists only for legacy reasons 1263 @property 1264 def enable_tracing(self) -> bool: 1265 return True 1266 1267 def __init__( 1268 self, 1269 tracer: _ProxyTracer, 1270 tracing_mode: str, 1271 pre_dispatch: bool = False, 1272 _allow_fake_constant: bool = False, 1273 _error_on_data_dependent_ops: bool = True, 1274 ) -> None: 1275 dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None 1276 super().__init__(dk) 1277 self.tracer = tracer 1278 self.tracing_mode = tracing_mode 1279 self.pre_dispatch = pre_dispatch 1280 self._allow_fake_constant = _allow_fake_constant 1281 self._error_on_data_dependent_ops = _error_on_data_dependent_ops 1282 # Indicates to our torch_dispatch dispatching infra that 1283 # this is an "infra" mode with lower dispatching precedence. 1284 self._mode_key = torch._C._TorchDispatchModeKey.PROXY 1285 # Every time we enter a mode, we maintain a stack telling us what the previous 1286 # ProxyTorchDispatchMode state was (if there was any). 1287 # This lets us properly reset the state on exit. 1288 self.enter_stack: List[Optional[ProxyTorchDispatchMode]] = [] 1289 self.decomp_layers = 0 1290 from torch._inductor import config 1291 1292 self.emulate_precision_casts = config.emulate_precision_casts 1293 1294 @count 1295 def __torch_dispatch__( 1296 self, 1297 func: OpOverload, 1298 types: Tuple[torch._C._TensorMeta, ...], 1299 args: Tuple[object, ...] = (), 1300 kwargs: Optional[Dict[str, object]] = None, 1301 ) -> object: 1302 with set_original_aten_op(func): 1303 kwargs = kwargs or {} 1304 1305 if func in (prim.device.default,): 1306 return func(*args, **kwargs) 1307 1308 return proxy_call(self, func, self.pre_dispatch, args, kwargs) 1309 1310 def __enter__(self) -> Self: 1311 # Stash and store the previous proxy mode (there may or may not be one) 1312 maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY) 1313 self.enter_stack.append(maybe_prev_proxy_mode) 1314 return super().__enter__() 1315 1316 def __exit__( 1317 self, 1318 exc_type: Optional[Type[BaseException]], 1319 exc_value: Optional[BaseException], 1320 traceback: Optional[types.TracebackType], 1321 ) -> Optional[bool]: 1322 b = super().__exit__(exc_type, exc_value, traceback) 1323 1324 # Re-enable the previous proxy mode, if there was one.
1325 mb_previous_proxy_mode = self.enter_stack.pop() 1326 if mb_previous_proxy_mode is not None: 1327 _push_mode(mb_previous_proxy_mode) 1328 1329 return b 1330 1331 @classmethod 1332 def is_infra_mode(cls) -> bool: 1333 return True 1334 1335 def _compute_proxy( 1336 self, func: OpOverload, args: Tuple[object, ...], out: PySymType 1337 ) -> Proxy: 1338 n_args = tuple( 1339 get_proxy_slot(a, self.tracer).force().node 1340 if isinstance(a, py_sym_types) 1341 else a 1342 for a in args 1343 ) 1344 1345 # func doesn't have a __torch_function__ that Proxy can interpose, so 1346 # we gotta do it manually 1347 n_out = self.tracer.create_node("call_function", func, n_args, {}) # type: ignore[arg-type] 1348 p_out = fx.Proxy(n_out, self.tracer) 1349 set_meta(p_out, out) 1350 return p_out 1351 1352 def __sym_dispatch__( 1353 self, 1354 func: OpOverload, 1355 types: Tuple[torch._C._TensorMeta, ...], 1356 args: Tuple[object, ...], 1357 kwargs: Dict[str, object], 1358 ) -> object: 1359 # Peephole optimize multiply by one 1360 # NB: be careful not to trigger guards here! 1361 if func == operator.mul: 1362 if isinstance(args[1], int) and args[1] == 1: 1363 return args[0] 1364 elif isinstance(args[0], int) and args[0] == 1: 1365 return args[1] 1366 1367 # For speed, we assume there are no nested data structures 1368 # (otherwise we could use tree_map) 1369 # We also assume there are no keyword arguments. 1370 assert not kwargs 1371 out = func(*args, **kwargs) 1372 1373 # If func returned a constant, we don't need to trace; we have 1374 # determined that the result is constant (no matter if the inputs 1375 # were symbolic) and it is no longer necessary to trace the 1376 # computation. This could occur if func triggered some guards. 1377 if isinstance(out, py_sym_types): 1378 p_out_thunk = thunkify( 1379 self.tracer, self._compute_proxy, func=func, args=args, out=out 1380 ) 1381 set_proxy_slot(out, self.tracer, p_out_thunk) 1382 1383 return out 1384 1385 1386class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer): 1387 script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] 1388 symnode_tracker: MutableMapping[PySymType, _PySymProxyType] 1389 tensor_tracker: MutableMapping[Tensor, _ProxyTensor] 1390 sympy_expr_tracker: Dict[sympy.Symbol, object] 1391 torch_fn_metadata: Optional[OpOverload] 1392 torch_fn_counts: Dict[OpOverload, int] 1393 enable_thunkify: bool = False 1394 1395 def __init__(self, graph: fx.graph.Graph) -> None: 1396 super().__init__(graph) 1397 self.symnode_tracker = weakref.WeakKeyDictionary() 1398 self.tensor_tracker = WeakTensorKeyDictionary() 1399 self.sympy_expr_tracker = {} 1400 self.script_object_tracker = WeakIdKeyDictionary( 1401 dict=None, ref_type=_WeakHashRef 1402 ) 1403 # Stores the torch function that was called during tracing 1404 self.torch_fn_metadata = None 1405 # Stores the counts for every torch function called. This is to help 1406 # distinguish between different calls to the same torch function. 
1407 self.torch_fn_counts = {} 1408 1409 1410# TODO: I'm not sure what the point of this class is; you can just 1411# make_fx through a regular Interpreter 1412class DecompositionInterpreter(fx.Interpreter): 1413 def __init__( 1414 self, 1415 module: fx.GraphModule, 1416 new_graph: fx.Graph, 1417 decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, 1418 **kwargs: object, 1419 ) -> None: 1420 super().__init__(module, **kwargs) # type: ignore[arg-type] 1421 self.new_graph = new_graph 1422 self.tracer = _GraphAppendingTracerEx(self.new_graph) 1423 # Blegh 1424 self.decomposition_table = decomposition_table or {} 1425 self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") 1426 1427 def placeholder( 1428 self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] 1429 ) -> object: 1430 out = super().placeholder(target, args, kwargs) # type: ignore[arg-type] 1431 proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer) 1432 track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) 1433 # TODO handle case where the first character of target is '*' 1434 return out 1435 1436 def get_attr( 1437 self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] 1438 ) -> object: 1439 out = super().get_attr(target, args, kwargs) # type: ignore[arg-type] 1440 proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer) 1441 track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) 1442 return out 1443 1444 # call_function, call_method, call_module get traced automatically by the outer mode. 1445 1446 def output( 1447 self, target: str, args: Tuple[object, ...], kwargs: Dict[str, object] # type: ignore[override] 1448 ) -> object: 1449 out = super().output(target, args, kwargs) # type: ignore[arg-type] 1450 1451 def get_proxy_node(x: _ProxyTensor) -> fx.node.Node: 1452 return x.proxy.node 1453 1454 def unwrap(e: Tensor) -> Union[Tensor, fx.Node]: 1455 return get_proxy_slot(e, self.tracer, e, get_proxy_node) 1456 1457 self.new_graph.output(pytree.tree_map(unwrap, out)) 1458 return out 1459 1460 def run(self, *args: object, **kwargs: object) -> object: 1461 # Should enter the mode at least once for being able to restore it later 1462 # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025 1463 with decompose(self.decomposition_table), self.mode: 1464 return super().run(*args, **kwargs) # type: ignore[arg-type] 1465 1466 1467def wrapper_and_args_for_make_fx( 1468 func: Callable[..., R], args: Tuple[object, ...], kwargs: Dict[str, object] 1469) -> Tuple[Callable[[List[object]], R], List[object]]: 1470 # make_fx doesn't support kwargs, so we need to do this flattening 1471 # and then unflatten the args before calling func 1472 flat_args, spec = pytree.tree_flatten((args, kwargs)) 1473 1474 def wrapped(flat_args: List[object]) -> R: 1475 fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec) 1476 return func(*fn_args, **fn_kwargs) 1477 1478 return wrapped, flat_args 1479 1480 1481@contextmanager 1482def disable_autocast_cache() -> Generator[None, None, None]: 1483 old_value = torch.is_autocast_cache_enabled() 1484 torch.set_autocast_cache_enabled(False) 1485 try: 1486 yield 1487 finally: 1488 torch.set_autocast_cache_enabled(old_value) 1489 1490 1491class _ModuleNotInstalledAsSubmoduleError(NameError): 1492 pass 1493 1494 1495# Base class for inline _ModuleStackTracer.__init__.AttrProxy 1496class _AttrProxy: 1497 def reset_proxy_mapping(self, base: Module, path: str) 
-> None: 1498 pass 1499 1500 1501class _ModuleStackTracer(PythonKeyTracer): 1502 r"""Customized version of PythonKeyTracer that retains module stack 1503 information in node.meta["nn_module_stack"]. 1504 1505 FX symbolic trace actually does this already, but it relies on `self.root` 1506 being the actual module being traced. Since make_fx traces a lambda of our 1507 creation, things don't work properly. 1508 1509 So for this version we hold onto a reference to the original module 1510 (scope_root) and use that to match the path. Also when we see, 1511 A 1512 / \ 1513 B C 1514 \ / 1515 D 1516 we want to record the path as A.B.D by recording only one path. 1517 See Note [Preserving the nn module stack metadata during export non-strict mode] # noqa: W605 1518 """ 1519 1520 def __init__(self, scope_root: GraphModule) -> None: 1521 super().__init__() 1522 self.scope_root = scope_root 1523 self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary() 1524 self.attr_proxy_map: WeakKeyDictionary[Module, _AttrProxy] = WeakKeyDictionary() 1525 self.proxy_modules: WeakKeyDictionary[_AttrProxy, Module] = WeakKeyDictionary() 1526 self.counter = 0 1527 1528 self.module_id_cache = defaultdict(list) 1529 for name, mod in self.scope_root.named_modules(remove_duplicate=False): 1530 self.module_id_cache[id(mod)].append(name) 1531 1532 # Build a wrapper around _AttrProxy to provide the tracer. We can't 1533 # store it on _AttrProxy itself beceause we mimic the underlying class 1534 # (including its attributes). 1535 tracer = self 1536 1537 class AttrProxy(_AttrProxy): 1538 def __init__(self, base: Module, path: str) -> None: 1539 # Class is modified to be a subclass of torch.nn.Module 1540 # Warning: We blow away our own attributes here to mimic the base class 1541 # - so don't expect `self.x` to do anything useful. 1542 self.__class__ = type( 1543 base.__class__.__name__, 1544 (self.__class__, base.__class__), 1545 {}, 1546 ) 1547 self.__dict__ = base.__dict__ 1548 self.__class__.__module__ = base.__class__.__module__ 1549 self.__class__.__qualname__ = base.__class__.__qualname__ 1550 self.reset_proxy_mapping(base, path) 1551 1552 def reset_proxy_mapping(self, base: Module, path: str) -> None: 1553 tracer.proxy_paths[self] = path 1554 tracer.proxy_modules[self] = base 1555 1556 def __getattr__(self, name: str) -> AttrProxy: 1557 assert isinstance(self, Module) 1558 # Calling into torch.nn.Module.__getattr__ with super(), 1559 # That __getattr__ is patched to be module_getattr_wrapper in _symbolic_trace.py. 1560 # which then calls into _ModuleStackTracer.getattr 1561 attr_val = super().__getattr__(name) # type: ignore[misc] 1562 if isinstance(attr_val, AttrProxy): 1563 attr_val = tracer.proxy_modules[attr_val] 1564 elif not isinstance(attr_val, Module): 1565 return attr_val 1566 if attr_val not in tracer.attr_proxy_map: 1567 tracer.attr_proxy_map[attr_val] = AttrProxy( 1568 attr_val, tracer.proxy_paths[self] + "." + name 1569 ) 1570 else: 1571 # NOTE [caching AttrProxy]. Caching ensures a 1-1 mapping between AttrProxy and the actual attr_val. 1572 # 1. We reset the proxy_mapping to solve the diamond shape reference problem: we want to record the 1573 # path as A.B.D instead of A.C.D (the purpose of _ModuleStackTracer). 1574 # 2. Instead of creating a new AttrProxy, we just reset the proxy_mapping of existing one. 
This is to avoid 1575 # dynamo creating multiple guards for the same attr_val but different AttrProxy when exporting 1576 # a model that calls torch.compile (e.g when a model uses torch.cond.) 1577 tracer.attr_proxy_map[attr_val].reset_proxy_mapping( 1578 attr_val, tracer.proxy_paths[self] + "." + name 1579 ) 1580 return tracer.attr_proxy_map[attr_val] 1581 1582 def get_base(self) -> Module: 1583 return tracer.proxy_modules[self] 1584 1585 @property 1586 def _modules(self) -> Dict[str, AttrProxy]: 1587 assert "_modules" in self.__dict__ 1588 submodules = self.__dict__["_modules"] 1589 assert isinstance(submodules, dict) 1590 return { 1591 key: AttrProxy(value, tracer.proxy_paths[self] + "." + str(key)) 1592 for key, value in submodules.items() 1593 } 1594 1595 self.proxy_type = AttrProxy 1596 1597 def path_of_module(self, mod: Module) -> str: 1598 """ 1599 Use tracked access path during tracing instead of the default BFS behavior. 1600 Still use all the possible module paths to verify the result. 1601 """ 1602 if mod is self.scope_root: 1603 return "" 1604 1605 if isinstance(mod, _AttrProxy): 1606 return self.proxy_paths[mod] 1607 1608 try: 1609 return Tracer.path_of_module(self, mod) 1610 except NameError as e: 1611 raise _ModuleNotInstalledAsSubmoduleError from e 1612 1613 def getattr( 1614 self, attr: str, attr_val: object, parameter_proxy_cache: Dict[str, Proxy] 1615 ) -> object: 1616 if not isinstance(attr_val, Module) or isinstance(attr_val, fx.GraphModule): 1617 return super().getattr(attr, attr_val, parameter_proxy_cache) 1618 if isinstance(attr_val, _AttrProxy): 1619 return attr_val 1620 1621 # See NOTE [caching AttrProxy]. 1622 if attr_val not in self.attr_proxy_map: 1623 self.attr_proxy_map[attr_val] = self.proxy_type(attr_val, attr) 1624 else: 1625 self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr) 1626 return self.attr_proxy_map[attr_val] 1627 1628 def trace( # type: ignore[override] 1629 self, root: Union[Module, Callable], concrete_args: Optional[Dict[str, object]] 1630 ) -> fx.Graph: 1631 res = super().trace(root, concrete_args) 1632 1633 # Since we are making _AttrProxy mimic the original 1634 # submodule, when someone registers a module directly 1635 # to the tracer while tracing, the proxy object gets registered 1636 # first. So we need to replace the proxy modules with the real ones 1637 # This can happen during HOO tracing 1638 proxy_module_names_to_be_replaced: List[Tuple[str, _AttrProxy]] = [] 1639 for name, module in self.root.named_modules(): 1640 if module in self.proxy_modules: 1641 proxy_module_names_to_be_replaced.append((name, module)) 1642 1643 def _delete_proxy_attr(obj: Module, target: str) -> bool: 1644 # Copied from fx/graph_module.py 1645 # Customized it for proxy type 1646 atoms = target.split(".") 1647 path, target_submod = atoms[:-1], atoms[-1] 1648 assert isinstance(obj, Module) 1649 mod = obj 1650 1651 # Get the parent module 1652 for item in path: 1653 if not hasattr(mod, item): 1654 return False 1655 1656 mod = getattr(mod, item) 1657 1658 if not isinstance(mod, (_AttrProxy, Module)): 1659 return False 1660 1661 if not hasattr(mod, target_submod): 1662 return False 1663 1664 # At least the leaf module should be proxy type. 
1665 if not isinstance(getattr(mod, target_submod), _AttrProxy): 1666 return False 1667 1668 delattr(mod, target_submod) 1669 return True 1670 1671 for proxy_module_name, proxy_module in proxy_module_names_to_be_replaced: 1672 _delete_proxy_attr(self.root, proxy_module_name) 1673 actual_module = self.proxy_modules[proxy_module] 1674 _assign_attr(actual_module, self.root, proxy_module_name) 1675 1676 return res 1677 1678 def call_module( 1679 self, 1680 m: Module, 1681 forward: Callable, 1682 args: Tuple[object, ...], 1683 kwargs: Dict[str, object], 1684 ) -> None: 1685 """PythonKeyTracer overrides call_module to avoid the scope handling, 1686 but we actually want it. 1687 """ 1688 from torch._dynamo import OptimizedModule 1689 1690 # FIXME (tmanlaibaatar) 1691 # When we call torch.compile inside HOO, we will end up 1692 # invoking a module that is not registered on the root. For 1693 # now, we just inline them. But once we start supporting 1694 # mark_strict in export, we do need to properly handle this. 1695 # Right now, it doesn't matter because current non-strict 1696 # use cases don't need to work with HOO. 1697 if isinstance(m, (OptimizedModule, GraphModule)): 1698 return forward(*args, **kwargs) 1699 1700 try: 1701 return Tracer.call_module(self, m, forward, args, kwargs) 1702 except _ModuleNotInstalledAsSubmoduleError as e: 1703 warnings.warn( 1704 f"Unable to find the path of the module {m}. " 1705 "This might be because the module was not properly registered " 1706 "as a submodule, which is not good practice. We will trace " 1707 "through the module without recording stack information." 1708 ) 1709 return forward(*args, **kwargs) 1710 1711 def is_leaf_module(self, m: Module, module_qualified_name: str) -> bool: 1712 return False 1713 1714 def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: 1715 """ 1716 Create node and add on metadata. 1717 Add nn_module_stack here instead of TracerBase, 1718 since calls to make_fx() might not want to record module stack metadata. 1719 Add torch_fn by looking at torch_fn_metadata and torch_fn_counts. 1720 Add stack_trace by filtering out forward() stack frames. 1721 """ 1722 node = super().create_node(*args, **kwargs) # type: ignore[arg-type] 1723 1724 # nn_module_stack 1725 if node.op not in ["placeholder", "output"]: 1726 if "nn_module_stack" not in node.meta: 1727 node.meta["nn_module_stack"] = self.module_stack 1728 # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]] 1729 for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items(): 1730 if isinstance(mod_cls, type): 1731 node.meta["nn_module_stack"][key] = ( 1732 fqn, 1733 mod_cls.__module__ + "." + mod_cls.__qualname__, 1734 ) 1735 1736 # torch_fn 1737 if ( 1738 node.op == "call_function" 1739 and self.torch_fn_metadata is not None 1740 and "torch_fn" not in node.meta 1741 ): 1742 node.meta["torch_fn"] = ( 1743 f"{self.torch_fn_metadata.__name__}_{self.torch_fn_counts[self.torch_fn_metadata]}", 1744 f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}", 1745 ) 1746 1747 # stack_trace 1748 if "stack_trace" not in node.meta and node.op not in ["placeholder", "output"]: 1749 user_frame_summary = CapturedTraceback.extract().summary() 1750 if user_frame_summary: 1751 # we retain frames from forward() calls, or ops 1752 # located in torch/__init__.py (e.g. 
                stack_trace = [
                    frame
                    for frame in user_frame_summary
                    if (
                        frame.name == "forward"
                        or frame.filename.endswith("torch/__init__.py")
                    )
                ]
                # filter out forward() frames from fx/_symbolic_trace.py and export/_trace.py;
                # this is hardcoded, but leads to a much cleaner stack trace
                stack_trace = [
                    frame
                    for frame in stack_trace
                    if not (
                        frame.filename.endswith("fx/_symbolic_trace.py")
                        or frame.filename.endswith("export/_trace.py")
                    )
                ]
                if (
                    stack_trace
                ):  # empty list for strict mode; dynamo should handle stack_trace
                    stack_trace = traceback.StackSummary.from_list(stack_trace)
                    node.meta["stack_trace"] = "".join(stack_trace.format()).strip()

        return node


class _MakefxTracer:
    def __init__(
        self,
        decomposition_table: Optional[Mapping[OpOverload, Callable]],
        tracing_mode: str,
        _allow_non_fake_inputs: bool,
        pre_dispatch: bool,
        record_module_stack: bool,
        _allow_fake_constant: bool,
        _error_on_data_dependent_ops: bool,
    ) -> None:
        # Configurations that are used to initialize the context managers and their states.
        # They should not be modified during tracing.
        self.decomposition_table: Dict[OpOverload, Callable] = dict(
            decomposition_table or {}
        )
        self.decomposition_table.setdefault(
            torch.ops.aten.sym_numel.default, torch._decomp.decompositions.sym_numel
        )
        self.tracing_mode: str = tracing_mode
        self._allow_non_fake_inputs: bool = _allow_non_fake_inputs
        self.pre_dispatch: bool = pre_dispatch
        self.record_module_stack: bool = record_module_stack
        self._allow_fake_constant: bool = _allow_fake_constant
        self._error_on_data_dependent_ops: bool = _error_on_data_dependent_ops

        # All context managers and their states should be initialized before tracing,
        # based on the inputs and configurations. After tracing, their states should be
        # cleaned up, except for shape_env. Remember to specify how to initialize it from
        # user inputs and from the parent tracer whenever adding new modes in _MakefxTracer.
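        # (None / nullcontext() below are inert placeholders; the actual modes and
        # tracer are installed by _init_modes_from_inputs or _init_modes_from_parent
        # before tracing starts.)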
        self.fake_tensor_mode: Optional[FakeTensorMode] = None
        self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext()
        self.proxy_function_mode: Union[
            nullcontext, PreDispatchTorchFunctionMode
        ] = nullcontext()
        self.fx_tracer: Optional[PythonKeyTracer] = None
        self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext()
        self.torch_fn_metadata_mode: Union[
            nullcontext, TorchFunctionMetadataMode
        ] = nullcontext()

    def _checkpoint_modes(self) -> List[Any]:
        return [
            self.fake_tensor_mode,
            self.proxy_mode,
            self.proxy_function_mode,
            self.fx_tracer,
            self.python_dispatcher_mode,
            self.torch_fn_metadata_mode,
        ]

    def _restore_modes(
        self,
        prev_fake_tensor_mode: Optional[FakeTensorMode],
        prev_proxy_mode: Union[nullcontext, ProxyTorchDispatchMode],
        prev_proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode],
        prev_fx_tracer: Optional[PythonKeyTracer],
        prev_python_dispatcher_mode: Union[nullcontext, Any],
        prev_torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode],
    ) -> None:
        self.fake_tensor_mode = prev_fake_tensor_mode
        self.proxy_mode = prev_proxy_mode
        self.proxy_function_mode = prev_proxy_function_mode
        self.fx_tracer = prev_fx_tracer
        self.python_dispatcher_mode = prev_python_dispatcher_mode
        self.torch_fn_metadata_mode = prev_torch_fn_metadata_mode

    @contextmanager
    def _init_modes_from_inputs(
        self, f: Callable, args: Tuple[object, ...]
    ) -> Generator[None, None, None]:
        prev_modes = self._checkpoint_modes()
        try:
            # Avoid importing sympy at a module level
            from .symbolic_shapes import ShapeEnv

            if hasattr(f, "_orig_mod") and self.record_module_stack:
                scope_root = f._orig_mod
                self.fx_tracer = _ModuleStackTracer(scope_root)
            else:
                self.fx_tracer = PythonKeyTracer()

            if self.tracing_mode == "fake":
                import torch._dynamo

                fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
                if fake_tensor_mode is None:
                    import torch._functorch.config as _config

                    with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
                        fake_tensor_mode = FakeTensorMode(
                            allow_fallback_kernels=True,
                            allow_non_fake_inputs=self._allow_non_fake_inputs,
                            shape_env=ShapeEnv(),
                            static_shapes=True,
                        )
                self.fake_tensor_mode = fake_tensor_mode
            elif self.tracing_mode == "symbolic":
                import torch._dynamo

                fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
                if fake_tensor_mode is None:
                    shape_env = ShapeEnv()
                    import torch._functorch.config as _config

                    with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
                        fake_tensor_mode = FakeTensorMode(
                            allow_fallback_kernels=False,
                            allow_non_fake_inputs=self._allow_non_fake_inputs,
                            shape_env=shape_env,
                        )
                assert (
                    fake_tensor_mode.shape_env is not None
                ), "shape_env should be set if tracing with 'symbolic'"
                self.fake_tensor_mode = fake_tensor_mode
            else:
                if self.tracing_mode != "real":
                    raise AssertionError(
                        f"Unexpected tracing type: {self.tracing_mode}"
                    )

            self._construct_modes_with_fx_tracer(self.fx_tracer)
            yield
        finally:
            self._restore_modes(*prev_modes)

    def _construct_modes_with_fx_tracer(self, fx_tracer: _ProxyTracer) -> None:
        self.proxy_mode = ProxyTorchDispatchMode(
            fx_tracer,
            self.tracing_mode,
            pre_dispatch=self.pre_dispatch,
            _allow_fake_constant=self._allow_fake_constant,
            _error_on_data_dependent_ops=self._error_on_data_dependent_ops,
        )

        if self.pre_dispatch:
            self.proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer)

        # pre-autograd tracing uses per-dispatch-key modes,
        # which requires the python dispatcher
        if self.tracing_mode == "symbolic" or self.pre_dispatch:
            self.python_dispatcher_mode = enable_python_dispatcher()

        self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer)

    @contextmanager
    def _init_modes_from_parent(
        self, parent_tracer: _MakefxTracer
    ) -> Generator[None, None, None]:
        # By default, a subtracer creates new modes based on the parent tracer's config.
        # However, there are cases where we want to share the same modes with the parent
        # tracer. For example, for fake_tensor_mode we want the example values of the
        # parent graph and its subgraphs to share the same fake mode.
        prev_modes = self._checkpoint_modes()
        try:
            self.fake_tensor_mode = parent_tracer.fake_tensor_mode

            def _create_sub_fx_tracer(parent_tracer: _ProxyTracer) -> PythonKeyTracer:
                if type(parent_tracer) == PythonKeyTracer:
                    return PythonKeyTracer()
                elif type(parent_tracer) == _ModuleStackTracer:
                    return _ModuleStackTracer(parent_tracer.scope_root)
                else:
                    raise RuntimeError(
                        f"Unexpected tracer type: {type(parent_tracer)}."
                    )

            assert parent_tracer.fx_tracer is not None
            self.fx_tracer = _create_sub_fx_tracer(parent_tracer.fx_tracer)
            self._construct_modes_with_fx_tracer(self.fx_tracer)
            yield
        finally:
            self._restore_modes(*prev_modes)

    def _trace_inner(self, f: Callable, *args: object) -> GraphModule:
        phs = pytree.tree_map(lambda _: torch.fx._symbolic_trace.PH, args)

        def _wrap_fake(args: T) -> T:
            arg_count = 0

            def inner_wrap_fake(x: object) -> object:
                nonlocal arg_count
                # TODO: it would be nice to line these up with the names
                # FX will choose for the placeholders, but we don't
                # actually know what the names will be at this point yet
                # NB: the Source here is actually meaningless
                from torch._dynamo.source import ConstantSource

                assert self.fake_tensor_mode is not None
                source = ConstantSource(f"input{arg_count}")
                if isinstance(x, Tensor):
                    arg_count += 1
                    return self.fake_tensor_mode.from_tensor(x, source=source)
                # NB: don't match on bools
                elif type(x) is int and self.tracing_mode == "symbolic":
                    assert (
                        self.fake_tensor_mode.shape_env is not None
                    ), "shape_env should be set if tracing with 'symbolic'"
                    return self.fake_tensor_mode.shape_env.create_symintnode(
                        self.fake_tensor_mode.shape_env.create_symbol(
                            x, source, positive=None
                        ),
                        hint=x,
                        source=source,
                    )
                elif isinstance(x, torch.ScriptObject):
                    return torch._library.fake_class_registry.maybe_to_fake_obj(
                        self.fake_tensor_mode, x
                    )

                assert not isinstance(
                    x, FakeScriptObject
                ), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
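                # Any other leaf (bools, floats, strings, None, ...) is passed
                # through to the traced function unchanged.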
                return x

            wrap_fn_map = {
                "real": lambda x: x,
                "fake": inner_wrap_fake,
                "symbolic": inner_wrap_fake,
            }
            return pytree.tree_map(wrap_fn_map[self.tracing_mode], args)

        def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]:
            if (
                not hasattr(inspect.unwrap(f), "__code__")
                or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS
            ):
                # FX doesn't support varargs, so we have to fake up a wrapper
                # TODO: Would be nice to fix this at the source...
                return fake_signature(f, len(phs))
            return f

        args = _wrap_fake(args)
        func = _wrap_func(f, phs)
        # We disable the autocast cache: the autocast cache causes type conversions
        # on parameters to check a cache, which introduces untracked tensors into
        # the graph.
        #
        # We also disable tracing by any other tensor proxy-based tracers except
        # the current one. The purpose of `make_fx` is to produce graphmodules as
        # a side effect; its internal execution is thus irrelevant to any external
        # functional trace.
        proxy_mode: ProxyTorchDispatchMode = typing.cast(
            ProxyTorchDispatchMode, self.proxy_mode
        )
        with ExitStack() as stack:
            stack.enter_context(decompose(self.decomposition_table))
            if self.fake_tensor_mode:
                stack.enter_context(self.fake_tensor_mode)
            stack.enter_context(self.python_dispatcher_mode)
            stack.enter_context(self.proxy_function_mode)
            stack.enter_context(self.torch_fn_metadata_mode)
            stack.enter_context(proxy_mode)
            stack.enter_context(disable_autocast_cache())
            stack.enter_context(_set_make_fx_tracer(self))

            assert self.fx_tracer is not None
            t = dispatch_trace(
                wrap_key(func, args, self.fx_tracer, self.pre_dispatch),
                tracer=self.fx_tracer,
                concrete_args=tuple(phs),
            )

        # TODO: kind of a bad way to do it, should maybe figure out a better way
        if self.tracing_mode == "symbolic":
            assert self.fake_tensor_mode is not None
            t.shape_env = self.fake_tensor_mode.shape_env
        return t

    def trace(self, f: Callable, *args: object) -> fx.GraphModule:
        with self._init_modes_from_inputs(f, args):
            return self._trace_inner(f, *args)

    def trace_subgraph(self, f: Callable, *args: object) -> GraphModule:
        # Create a new tracer based on the parent's config
        sub_tracer = _MakefxTracer(
            self.decomposition_table,
            "real",
            self._allow_non_fake_inputs,
            self.pre_dispatch,
            self.record_module_stack,
            self._allow_fake_constant,
            self._error_on_data_dependent_ops,
        )
        with sub_tracer._init_modes_from_parent(self):
            return sub_tracer._trace_inner(f, *args)


_CURRENT_MAKE_FX_TRACER: Optional[_MakefxTracer] = None


@contextmanager
def _set_make_fx_tracer(tracer: _MakefxTracer) -> Generator[None, None, None]:
    global _CURRENT_MAKE_FX_TRACER
    prev_tracer = _CURRENT_MAKE_FX_TRACER
    try:
        _CURRENT_MAKE_FX_TRACER = tracer
        yield
    finally:
        _CURRENT_MAKE_FX_TRACER = prev_tracer


def make_fx(
    f: Callable,
    decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
    tracing_mode: str = "real",
    _allow_non_fake_inputs: bool = False,
    *,
    pre_dispatch: bool = False,
    record_module_stack: bool = False,
    _allow_fake_constant: bool = False,
    _error_on_data_dependent_ops: bool = True,
) -> Callable[..., GraphModule]:
    """
    Given a function f, return a new function which, when executed with valid
    arguments to f, returns an FX GraphModule representing the set of operations
    that were executed during the course of execution.
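
    Example (an illustrative sketch; the exact traced graph depends on the
    PyTorch version and on the decompositions in effect)::

        def f(x):
            return torch.sin(x) + x

        gm = make_fx(f, tracing_mode="fake")(torch.randn(4))
        print(gm.code)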
    """

    assert tracing_mode in ["real", "fake", "symbolic"]

    make_fx_tracer = _MakefxTracer(
        decomposition_table,
        tracing_mode,
        _allow_non_fake_inputs,
        pre_dispatch,
        record_module_stack,
        _allow_fake_constant,
        _error_on_data_dependent_ops,
    )

    @functools.wraps(f)
    def wrapped(*args: object) -> GraphModule:
        return make_fx_tracer.trace(f, *args)

    return wrapped


def get_torch_dispatch_modes() -> List[TorchDispatchMode]:
    return torch.utils._python_dispatch._get_current_dispatch_mode_stack()


# TODO: this is a legacy name; there is only ever one proxy mode, as it's an
# infra mode
def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
    return get_proxy_mode()


def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
    """
    Return the currently active proxy tracing mode, or None if
    we are not currently tracing. This includes pre-dispatch proxy
    tracing.
    """
    pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch(
        torch._C._TorchDispatchModeKey.PROXY
    )
    mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
    assert (
        pre_dispatch_mode is None or mode is None
    ), f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
    return pre_dispatch_mode or mode


def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) -> R:
    """
    Call into the currently active proxy tracing mode to do a
    SymInt/SymFloat/SymBool dispatch trace on a function that operates on
    these arguments.
    """
    mode = get_proxy_mode()
    assert mode
    # Have to do it manually, because we're not doing the normal torch
    # dispatch machinery which disables it for us
    with disable_proxy_modes_tracing():
        # TODO: properly compute types
        types: List[Type] = []
        return mode.__sym_dispatch__(func, types, args, kwargs)  # type: ignore[arg-type, return-value]


@contextmanager
def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, None]:
    with _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY) as mode:
        yield mode


def maybe_handle_decomp(
    proxy_mode: ProxyTorchDispatchMode,
    op: OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
) -> object:
    if op in CURRENT_DECOMPOSITION_TABLE:
        with proxy_mode:
            proxy_mode.decomp_layers += 1
            out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
            proxy_mode.decomp_layers -= 1
            return out

    return NotImplemented


def get_isolated_graphmodule(
    func: Callable,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
    tracing_mode: str = "real",
    decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
) -> GraphModule:
    """A helper function used to get the GraphModule for the given func.

    It's expected to be used in the ProxyTensor tracing context.
    It detaches the args and kwargs from the current tracer so that the trace of
    the current graph module can be created without any side-effects.
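
    Example (illustrative only; ``x`` and ``y`` stand for tensors that are live
    inside the surrounding ProxyTensor trace)::

        gm = get_isolated_graphmodule(torch.add, (x, y), {})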
    """
    wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)

    with disable_proxy_modes_tracing():
        gm = make_fx(
            wrapped, decomposition_table=decomposition_table, tracing_mode=tracing_mode
        )(all_args)
    return gm


def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
    """A helper function for setting up unbacked_bindings on the destination FX graph."""
    from .symbolic_shapes import compute_unbacked_bindings

    # Can't use detect_fake_mode here,
    #
    # python test/distributed/_tensor/test_dtensor_compile.py -k
    # test_tp_compile_fullgraph_is_seq_parallel_False
    #
    # will fail. Very strange, it probably isn't right for them to be using
    # two fake modes there...
    fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
    if fake_mode and fake_mode.shape_env:
        if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out):
            assert isinstance(out_proxy, Proxy), out_proxy
            out_proxy.node.meta["unbacked_bindings"] = symbol_to_path
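

# Example usage (an illustrative sketch kept as a comment so that importing this
# module has no side effects; the symbolic size names such as s0/s1 in the printed
# graph depend on the PyTorch version):
#
#     import torch
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     def f(x):
#         return x.reshape(-1) * 2
#
#     gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4))
#     gm.print_readable()       # placeholder sizes show up as SymInts (e.g. s0, s1)
#     shape_env = gm.shape_env  # ShapeEnv attached by _MakefxTracer._trace_inner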