1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4import ast 5import builtins 6import collections 7import dataclasses 8import enum 9import functools 10import importlib 11import inspect 12import itertools 13import logging 14import math 15import os 16import re 17import sys 18import textwrap 19import types 20import weakref 21from contextlib import contextmanager 22from copy import deepcopy 23from inspect import currentframe, getframeinfo 24from typing import ( 25 Any, 26 Callable, 27 Dict, 28 List, 29 Optional, 30 Set, 31 Tuple, 32 Type, 33 TYPE_CHECKING, 34 Union, 35) 36from weakref import ReferenceType 37 38import torch 39import torch.utils._device 40from torch._C._dynamo.guards import ( 41 check_obj_id, 42 check_type_id, 43 dict_version, 44 DictGuardManager, 45 install_no_tensor_aliasing_guard, 46 install_object_aliasing_guard, 47 RootGuardManager, 48 TensorGuards, 49) 50from torch._dynamo.source import ( 51 is_from_flatten_script_object_source, 52 is_from_local_source, 53 is_from_optimizer_source, 54 TensorProperty, 55 TensorPropertySource, 56) 57from torch._guards import ( 58 CompileContext, 59 CompileId, 60 DuplicateInputs, 61 Guard, 62 GuardBuilderBase, 63 GuardEnvExpr, 64 GuardSource, 65 Source, 66) 67from torch._logging import structured 68from torch._utils_internal import justknobs_check 69from torch.fx.experimental.symbolic_shapes import ( 70 EqualityConstraint, 71 is_symbolic, 72 SYMPY_INTERP, 73) 74from torch.utils._traceback import format_frame, report_compile_source_on_error 75from torch.utils.weak import TensorWeakRef 76 77from . 
import config, convert_frame, exc, mutation_guard 78from .eval_frame import set_guard_error_hook 79from .source import ( 80 AttrProxySource, 81 AttrSource, 82 ChainedSource, 83 ConstDictKeySource, 84 DefaultsSource, 85 FlattenScriptObjectSource, 86 FSDPNNModuleSource, 87 GetItemSource, 88 GlobalSource, 89 GlobalStateSource, 90 GlobalWeakRefSource, 91 GradSource, 92 LocalSource, 93 NNModuleSource, 94 NumpyTensorSource, 95 ODictGetItemSource, 96 OptimizerSource, 97 ScriptObjectQualifiedNameSource, 98 ShapeEnvSource, 99 SubclassAttrListSource, 100 TupleIteratorGetItemSource, 101 TypeSource, 102 UnspecializedBuiltinNNModuleSource, 103 UnspecializedNNModuleSource, 104 UnspecializedParamBufferSource, 105 WeakRefCallSource, 106) 107from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401 108from .utils import ( 109 common_constant_types, 110 dict_keys_repr, 111 get_custom_getattr, 112 get_torch_function_mode_stack, 113 guard_failures, 114 istype, 115 key_is_id, 116 key_to_id, 117 orig_code_map, 118 tensor_always_has_static_shape, 119 tuple_iterator_getitem, 120 tuple_iterator_len, 121 unpatched_nn_module_getattr, 122 verify_guard_fn_signature, 123) 124 125 126try: 127 import numpy as np 128except ModuleNotFoundError: 129 np = None # type: ignore[assignment] 130 131 132if TYPE_CHECKING: 133 from sympy import Symbol 134 135 136log = logging.getLogger(__name__) 137guards_log = torch._logging.getArtifactLogger(__name__, "guards") 138recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles") 139recompiles_verbose_log = torch._logging.getArtifactLogger( 140 __name__, "recompiles_verbose" 141) 142verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") 143 144 145class GuardManager: 146 """ 147 A helper class that contains the root guard manager. 
An instance of this 148 class is stored in the Dynamo cache entry, so that the cache entry can 149 access the RootGuardManager stored in the "root" attribute and directly call 150 the check_nopybind from C++. 151 """ 152 153 def __init__(self): 154 self.root = RootGuardManager() 155 156 self.closure_vars = None 157 self.args = None 158 self.code_parts = [] 159 self.verbose_code_parts = None 160 self.global_scope = None 161 self.guard_fail_fn = None 162 self.cache_entry = None 163 self.extra_state = None 164 self.id_matched_objs = None 165 self.no_tensor_aliasing_sources = [] 166 167 self.print_no_tensor_aliasing_guard = True 168 169 @contextmanager 170 def _preserve_print_no_tensor_aliasing_flag(self): 171 self.print_no_tensor_aliasing_guard = True 172 try: 173 yield 174 finally: 175 self.print_no_tensor_aliasing_guard = True 176 177 def get_guard_lines(self, guard): 178 guard_name = guard.__class__.__name__ 179 parts = guard.verbose_code_parts() 180 parts = [guard_name + ": " + part for part in parts] 181 return parts 182 183 def get_manager_line(self, guard_manager, accessor_str=None): 184 source = guard_manager.get_source() 185 t = guard_manager.__class__.__name__ 186 s = t + ": source=" + source 187 if accessor_str: 188 s += ", " + accessor_str 189 return s 190 191 def construct_dict_manager_string(self, mgr, body): 192 for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()): 193 body.writeline(f"KeyValueManager pair at index={idx}") 194 with body.indent(): 195 if key_mgr: 196 body.writeline(f"KeyManager: {self.get_manager_line(key_mgr)}") 197 self.construct_manager_string(key_mgr, body) 198 199 if val_mgr: 200 body.writeline(f"ValueManager: {self.get_manager_line(val_mgr)}") 201 self.construct_manager_string(val_mgr, body) 202 203 def construct_manager_string(self, mgr, body): 204 with body.indent(): 205 for guard in mgr.get_leaf_guards(): 206 if isinstance(guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING): # type: ignore[attr-defined] 207 
if self.print_no_tensor_aliasing_guard: 208 self.print_no_tensor_aliasing_guard = False 209 body.writelines(self.get_guard_lines(guard)) 210 else: 211 body.writelines( 212 [ 213 guard.__class__.__name__, 214 ] 215 ) 216 else: 217 body.writelines(self.get_guard_lines(guard)) 218 219 # This works for both DictGuardManager and SubclassedDictGuardManager 220 if isinstance(mgr, DictGuardManager): 221 self.construct_dict_manager_string(mgr, body) 222 223 # General case of GuardManager/RootGuardManager 224 for accessor, child_mgr in zip( 225 mgr.get_accessors(), mgr.get_child_managers() 226 ): 227 body.writeline( 228 self.get_manager_line(child_mgr, f"accessed_by={accessor.repr()}") 229 ) 230 self.construct_manager_string(child_mgr, body) 231 232 def __str__(self): 233 from torch._inductor.utils import IndentedBuffer 234 235 class IndentedBufferWithPrefix(IndentedBuffer): 236 def prefix(self): 237 return "| " * (self._indent * self.tabwidth) 238 239 def writeline(self, line, skip_prefix=False): 240 if skip_prefix: 241 super().writeline(line) 242 else: 243 super().writeline("+- " + line) 244 245 with self._preserve_print_no_tensor_aliasing_flag(): 246 body = IndentedBufferWithPrefix() 247 body.tabwidth = 1 248 body.writeline("", skip_prefix=True) 249 body.writeline("TREE_GUARD_MANAGER:", skip_prefix=True) 250 body.writeline("RootGuardManager") 251 self.construct_manager_string(self.root, body) 252 for guard in self.root.get_epilogue_lambda_guards(): 253 body.writelines(self.get_guard_lines(guard)) 254 return body.getvalue() 255 256 def check(self, x): 257 # Only needed for debugging purposes. 258 return self.root.check(x) 259 260 def check_verbose(self, x): 261 # Only needed for debugging purposes. 
262 return self.root.check_verbose(x) 263 264 def populate_code_parts_for_debugging(self): 265 # This should be called when the guard manager is fully populated 266 tensor_aliasing_guard_seen = False 267 268 def get_code_parts(leaf_guard): 269 code_parts = [] 270 for verbose_code_part in leaf_guard.verbose_code_parts(): 271 code_part = verbose_code_part.split("#")[0].rstrip() 272 code_parts.append(code_part) 273 return code_parts 274 275 def visit(mgr): 276 nonlocal tensor_aliasing_guard_seen 277 for guard in mgr.get_leaf_guards(): 278 if isinstance(guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING): # type: ignore[attr-defined] 279 if not tensor_aliasing_guard_seen: 280 self.code_parts.extend(get_code_parts(guard)) 281 tensor_aliasing_guard_seen = True 282 else: 283 self.code_parts.extend(get_code_parts(guard)) 284 285 for child_mgr in mgr.get_child_managers(): 286 visit(child_mgr) 287 288 visit(self.root) 289 290 291def from_numpy(a): 292 # If not numpy array, piggy back on e.g. tensor guards to check type 293 return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a 294 295 296# For user stack printing 297@functools.lru_cache(None) 298def uninteresting_files(): 299 import torch._dynamo.external_utils 300 301 mods = [ 302 torch._dynamo.external_utils, 303 ] 304 return {inspect.getfile(m) for m in mods} 305 306 307CLOSURE_VARS = { 308 "___check_type_id": check_type_id, 309 "___check_obj_id": check_obj_id, 310 "___odict_getitem": collections.OrderedDict.__getitem__, 311 "___key_to_id": key_to_id, 312 "___dict_version": dict_version, 313 "___dict_contains": lambda a, b: a in b, 314 "___tuple_iterator_len": tuple_iterator_len, 315 "___tuple_iterator_getitem": tuple_iterator_getitem, 316 "__math_isnan": math.isnan, 317 "__numpy_isnan": None if np is None else np.isnan, 318 "inf": float("inf"), 319 "__load_module": importlib.import_module, 320 "utils_device": torch.utils._device, 321 "device": torch.device, 322 "___from_numpy": from_numpy, 323 
"___as_tensor": torch.as_tensor, 324 "torch": torch, 325 "inspect": inspect, 326} 327 328if sys.version_info[:2] <= (3, 8): 329 # [Note: Python Version <= 3.8] 330 # This branch should be dropped when we drop support for Python 3.8. 331 # Reason: 'ast.unparse' function was introduced in Python 3.9. 332 333 try: 334 import astunparse # type: ignore[import] 335 336 def _ast_unparse(node: ast.AST) -> str: 337 return astunparse.unparse(node).replace("\n", "") 338 339 HAS_UNPARSE_FUNCTIONS = True 340 except ImportError: 341 HAS_UNPARSE_FUNCTIONS = False 342else: 343 HAS_UNPARSE_FUNCTIONS = True 344 345 def _ast_unparse(node: ast.AST) -> str: 346 return ast.unparse(node).replace("\n", "") 347 348 349def strip_function_call(name): 350 """ 351 "___odict_getitem(a, 1)" => "a" 352 "a.layers[slice(2)][0]._xyz" ==> "a" 353 "getattr(a.layers[slice(2)][0]._abc, '0')" ==> "a" 354 "getattr(getattr(a.x[3], '0'), '3')" ==> "a" 355 "a.layers[slice(None, -1, None)][0]._xyz" ==> "a" 356 """ 357 # recursively find valid object name in function 358 valid_name = re.compile("[A-Za-z_].*") 359 curr = "" 360 for char in name: 361 if char in " (": 362 curr = "" 363 elif char in "),[]": 364 if curr and curr != "None" and valid_name.match(curr): 365 return strip_function_call(curr) 366 else: 367 curr += char 368 369 return strip_getattr_getitem(name) 370 371 372def strip_getattr_getitem(name): 373 """ 374 "a[1]" => "a" 375 "a.foo" => "a" 376 """ 377 return re.split(r"[.\[]", name)[0] 378 379 380def get_verbose_code_part(code_part: str, guard: Guard) -> str: 381 extra = "" 382 if guard.user_stack: 383 for fs in reversed(guard.user_stack): 384 if fs.filename not in uninteresting_files(): 385 extra = f" # {format_frame(fs, line=True)}" 386 break 387 elif guard.stack: 388 extra = f" # {format_frame(guard.stack.summary()[-1])}" 389 390 return f"{code_part:<60}{extra}" 391 392 393def get_verbose_code_parts( 394 code_parts: Union[str | List[str]], guard: Guard 395) -> List[str]: 396 if not 
isinstance(code_parts, list): 397 code_parts = [code_parts] 398 return [get_verbose_code_part(code_part, guard) for code_part in code_parts] 399 400 401def convert_to_concrete_values(size_or_stride): 402 converted: List[Optional[int]] = [] 403 for dim in size_or_stride: 404 if not is_symbolic(dim): 405 converted.append(dim) 406 else: 407 assert isinstance(dim, torch.SymInt) 408 converted.append(dim.node.maybe_as_int()) 409 return converted 410 411 412def get_tensor_guard_code_part(value, name, sizes, strides): 413 pytype = type(value) 414 dispatch_key = ( 415 torch._C._dispatch_keys(value) | torch._C._dispatch_tls_local_include_set() 416 ) - torch._C._dispatch_tls_local_exclude_set() 417 dtype = value.dtype 418 device_index = value.device.index 419 requires_grad = value.requires_grad 420 guard_str = ( 421 f"check_tensor({name}, {pytype.__qualname__}, {dispatch_key}, {dtype}, " 422 f"device={device_index}, requires_grad={requires_grad}, size={sizes}, stride={strides})" 423 ) 424 return guard_str 425 426 427def get_key_index(dct, key): 428 return list(dct.keys()).index(key) 429 430 431def get_key_index_source(source, index): 432 return f"list({source}.keys())[{index}]" 433 434 435@dataclasses.dataclass(frozen=True) 436class NNModuleAttrAccessorInfo: 437 # Represents where is the attr name is present in the nn module attribute 438 # access 439 440 # Tells that the attribute can be accessed via __dict__ 441 present_in_generic_dict: bool = False 442 443 # Either the actual name or _parameters/_buffers/_modules 444 l1_key: Optional[str] = None 445 446 # Actual paramter/buffer/submodule name 447 l2_key: Optional[str] = None 448 449 450def getitem_on_dict_manager( 451 source, base_guard_manager, base_example_value, example_value, guard_manager_enum 452): 453 base_source_name = source.base.name() 454 source_name = source.name() 455 if isinstance(source.index, ConstDictKeySource): 456 index = source.index.index 457 else: 458 assert isinstance(base_example_value, dict) 459 
index = get_key_index(base_example_value, source.index) 460 461 key_source = get_key_index_source(base_source_name, index) 462 key_example_value = list(base_example_value.keys())[index] 463 if isinstance(key_example_value, (int, str)): 464 value_source = f"{base_source_name}[{key_example_value!r}]" 465 else: 466 value_source = f"{base_source_name}[{key_source}]" 467 if not isinstance(source.index, ConstDictKeySource): 468 # We have to insert a key manager guard here 469 # TODO - source debug string is probably wrong here. 470 base_guard_manager.get_key_manager( 471 index=index, 472 source=key_source, 473 example_value=source.index, 474 guard_manager_enum=GuardManagerType.GUARD_MANAGER, 475 ).add_equals_match_guard( 476 source.index, [f"{key_source} == {key_example_value!r}"] 477 ) 478 479 return base_guard_manager.get_value_manager( 480 index=index, 481 source=value_source, 482 example_value=example_value, 483 guard_manager_enum=guard_manager_enum, 484 ) 485 486 487def match_on_id_for_tensor(guard): 488 source = guard.originating_source 489 return source.is_dict_key() and not isinstance(source, GradSource) 490 491 492# The ready to eval generated code (possibly multiple parts) for a guard, plus 493# the original guard object that created it for provenance 494@dataclasses.dataclass 495class GuardCodeList: 496 code_list: List[str] 497 guard: Guard 498 499 500class GuardManagerType(enum.Enum): 501 GUARD_MANAGER = 1 502 DICT_GUARD_MANAGER = 2 503 DICT_SUBCLASS_GUARD_MANAGER = 3 504 505 506class GuardBuilder(GuardBuilderBase): 507 def __init__( 508 self, 509 id_ref: Callable[[Any], str], 510 source_ref: Callable[[Source], str], 511 lookup_weakrefs: Callable[[object], ReferenceType[object]], 512 local_scope: Dict[str, object], 513 global_scope: Dict[str, object], 514 guard_manager: Optional[GuardManager], 515 check_fn_manager: CheckFunctionManager, 516 ): 517 self.id_ref = id_ref 518 self.source_ref = source_ref 519 self.lookup_weakrefs = lookup_weakrefs 520 self.scope: 
Dict[str, Dict[str, object]] = {"L": local_scope, "G": global_scope} 521 self.scope["__builtins__"] = builtins.__dict__.copy() 522 for ( 523 name, 524 package_module, 525 ) in torch.package.package_importer._package_imported_modules.items(): 526 name = name.replace(">", "_").replace("<", "_").replace(".", "_dot_") 527 # Write the package module into the scope so that we can import it 528 self.scope["__builtins__"][name] = package_module 529 # Write the demangled name to the scope so that we can use it 530 self.scope[name] = package_module 531 self.guard_manager = guard_manager 532 533 self.argnames: List[str] = [] 534 # Code is python expression strings generated for each guard 535 self.code: List[GuardCodeList] = [] 536 # shape_env_code is only used by builder and is used for 537 # shape env code. This exists only because we need to make sure 538 # shape env guards get run after tensor match guards (since the 539 # tensor match guards make sure we actually have tensors) 540 self.shape_env_code: List[GuardCodeList] = [] 541 542 # [Note - On Eager Tensor Guards] 543 # Most of the time, we generate Python code in a guard to directly 544 # check various properties. However, tensors are a bit special; 545 # it is too slow to check their properties one-by-one in Python. 546 # Instead, there is a C++ function TensorGuards.check which takes 547 # all of the tensor arguments and checks them all against compile-time 548 # examples entirely in C++. Thus, every time we process a 549 # TENSOR_MATCH guard, we just add another entry to 550 # tensor_check_names/tensor_check_examples, saying "for this local, 551 # check it against this example", and it all ends up getting 552 # swept up into a single call to ___check_tensors. Invariant: 553 # len(tensor_check_names) == len(tensor_check_examples). 
554 # TODO: something here 555 self.tensor_check_names: List[str] = [] 556 self.tensor_check_examples: List[torch.Tensor] = [] 557 self.tensor_check_guards: List[Guard] = [] 558 self.tensor_check_guard_managers: List[GuardManager] = [] 559 560 self.check_fn_manager: CheckFunctionManager = check_fn_manager 561 562 # Collect the ids of dicts which need key order guarding. source_name is 563 # not sufficient because for nn modules, we can have different sources 564 # to access the same object - self._module["param"] is same as 565 # self.param. 566 self.key_order_guarded_dict_ids = set() 567 for source_name in self.check_fn_manager.output_graph.guard_on_key_order: 568 self.key_order_guarded_dict_ids.add(id(self.get(source_name))) 569 570 # Keep track of weak references of objects with ID_MATCH guard. This 571 # info is stored alongside optimized_code and check_fn and is used to 572 # limit the number of cache entries with same ID_MATCH'd object. 573 self.id_matched_objs: Dict[str, ReferenceType[object]] = {} 574 575 # Save the guard managers to avoid repeatedly traversing sources. 576 self._cached_guard_managers: Dict[ 577 str, torch._C._dynamo.guards.GuardManager 578 ] = {} 579 580 self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set() 581 582 def guard_on_dict_keys_and_ignore_order(self, example_value, guard): 583 dict_mgr = self.get_guard_manager(guard) 584 if isinstance(dict_mgr, DictGuardManager): 585 raise NotImplementedError( 586 "Not expecting a DictGuardManager. Seems like Dynamo incorrectly " 587 f"added the dict to tx.output.guard_on_key_order for {guard.name}" 588 ) 589 590 # Iterate over the dicts and install a dict_getitem_manager. 
591 dict_source = guard.originating_source.name() 592 for key in example_value.keys(): 593 value = example_value[key] 594 value_source = GetItemSource(guard.originating_source, index=key) 595 guard_manager_enum = self.get_guard_manager_type( 596 value_source, example_value 597 ) 598 dict_mgr.dict_getitem_manager( 599 key=key, 600 source=f"{dict_source}[{key!r}]", 601 example_value=value, 602 guard_manager_enum=guard_manager_enum, 603 ) 604 605 def guard_on_dict_keys_and_order(self, value, guard): 606 # Add key managers for the DictGuardManager. Then add either an 607 # ID_MATCH or EQUALS_MATCH guard on the key. 608 dict_mgr = self.get_guard_manager(guard) 609 if not isinstance(dict_mgr, DictGuardManager): 610 raise NotImplementedError( 611 "Expecting a DictGuardManager. Seems like Dynamo forgot " 612 f"to set the right guard manager enum for {guard.name}" 613 ) 614 assert isinstance(dict_mgr, DictGuardManager) 615 616 for idx, key in enumerate(value.keys()): 617 key_source = get_key_index_source(guard.name, idx) 618 key_manager = dict_mgr.get_key_manager( 619 index=idx, 620 source=key_source, 621 example_value=key, 622 guard_manager_enum=GuardManagerType.GUARD_MANAGER, 623 ) 624 if key_is_id(key): 625 # Install ID_MATCH guard 626 id_val = self.id_ref(key) 627 key_manager.add_id_match_guard( 628 id_val, 629 get_verbose_code_parts( 630 f"__check_obj_id({key_source}, {id_val})", guard 631 ), 632 ) 633 else: 634 # Install EQUALS_MATCH guard 635 key_manager.add_equals_match_guard( 636 key, get_verbose_code_parts(f"{key_source} == {key!r}", guard) 637 ) 638 639 def getattr_on_nn_module( 640 self, 641 source, 642 base_guard_manager, 643 base_example_value, 644 example_value, 645 base_source_name, 646 source_name, 647 guard_manager_enum, 648 ): 649 """ 650 This tries to avoid calling the expensive nn module custom getattr method by 651 checking if the attribute is accessible via __dict__. 
For attributes that 652 are not accessible via __dict__ (like descriptors), we fallback to 653 PyObject_GetAttr. 654 655 There are two cases that we optimize for 656 1) attributes present directly in __dict__, e.g training. 657 2) parameters/buffers/modules - they can be accessed via _parameters, 658 _buffers, _modules keys in __dict__. For example, mod.linear can be 659 accessed as mod.__dict__["_parameters"]["linear"] 660 661 The most common and expensive case for nn module guards is of type 662 mod.submod1.submod2.submod3.training. We avoid the python getattr of nn 663 modules by going through the __dict__. 664 """ 665 666 def getitem_on_dict_mgr( 667 mgr, key, source_name, base_example_value, example_value, guard_manager_enum 668 ): 669 if isinstance(mgr, DictGuardManager): 670 # Case where the user code relies on key order, e.g., 671 # named_parameters 672 index = get_key_index(base_example_value, key) 673 674 # Install the key manager and add equals match guard 675 key_source = f"list({source_name}.keys())[{index!r}]" 676 mgr.get_key_manager( 677 index=index, 678 source=key_source, 679 example_value=key, 680 guard_manager_enum=GuardManagerType.GUARD_MANAGER, 681 ).add_equals_match_guard(key, [f"{key_source} == {key!r}"]) 682 683 # Install the value manager 684 return mgr.get_value_manager( 685 index=index, 686 source=source_name, 687 example_value=example_value, 688 guard_manager_enum=guard_manager_enum, 689 ) 690 else: 691 return mgr.dict_getitem_manager( 692 key=key, 693 source=source_name, 694 example_value=example_value, 695 guard_manager_enum=guard_manager_enum, 696 ) 697 698 attr_name = source.member 699 mod_dict = base_example_value.__dict__ 700 701 all_class_attribute_names: Set[str] = set() 702 for x in inspect.getmro(base_example_value.__class__): 703 all_class_attribute_names.update(x.__dict__.keys()) 704 705 accessor_info = NNModuleAttrAccessorInfo(False, None, None) 706 707 if attr_name in mod_dict: 708 accessor_info = 
NNModuleAttrAccessorInfo(True, attr_name, None) 709 elif "_parameters" in mod_dict and attr_name in mod_dict["_parameters"]: 710 accessor_info = NNModuleAttrAccessorInfo(True, "_parameters", attr_name) 711 elif "_buffers" in mod_dict and attr_name in mod_dict["_buffers"]: 712 accessor_info = NNModuleAttrAccessorInfo(True, "_buffers", attr_name) 713 elif ( 714 attr_name not in all_class_attribute_names 715 and "_modules" in mod_dict 716 and attr_name in mod_dict["_modules"] 717 ): 718 # Check test_attr_precedence test - instance attributes always take precedence unless its an nn.Module. 719 accessor_info = NNModuleAttrAccessorInfo(True, "_modules", attr_name) 720 721 if not accessor_info.present_in_generic_dict: 722 # The attribute can be accessed by __getattribute__ call, so rely on 723 # PyObject_GetAttr 724 return base_guard_manager.getattr_manager( 725 attr=source.member, 726 source=source_name, 727 example_value=example_value, 728 guard_manager_enum=guard_manager_enum, 729 ) 730 else: 731 assert accessor_info.l1_key 732 l1_key = accessor_info.l1_key 733 l2_key = accessor_info.l2_key 734 735 # Set source strings for debug info 736 mod_dict_source = f"{base_source_name}.__dict__" 737 l1_source_name = l2_source_name = None 738 l1_value = l2_value = None 739 l1_guard_manager_enum = l2_guard_manager_enum = None 740 if l2_key: 741 l1_source = AttrSource(source.base, l1_key) 742 l1_source_name = l1_source.name() 743 l1_value = mod_dict[l1_key] 744 # do not guard on key order for _parameters etc unless the user code 745 # actually needs the key order (e.g. 
calling named_parameters) 746 l1_guard_manager_enum = self.get_guard_manager_type(l1_source, l1_value) 747 748 l2_source_name = source_name 749 l2_value = example_value 750 l2_guard_manager_enum = self.get_guard_manager_type( 751 source, example_value 752 ) 753 else: 754 l1_source_name = source_name 755 l1_value = example_value 756 l1_guard_manager_enum = self.get_guard_manager_type( 757 source, example_value 758 ) 759 760 # Get __dict__ accessor. No need to guard on dict key order, so use base 761 # Guard Manager 762 mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager( 763 source=mod_dict_source, 764 example_value=mod_dict, 765 guard_manager_enum=GuardManagerType.GUARD_MANAGER, 766 ) 767 768 l1_mgr = getitem_on_dict_mgr( 769 mgr=mod_generic_dict_manager, 770 key=l1_key, 771 source_name=l1_source_name, 772 base_example_value=mod_dict, 773 example_value=l1_value, 774 guard_manager_enum=l1_guard_manager_enum, 775 ) 776 777 if l2_key: 778 return getitem_on_dict_mgr( 779 mgr=l1_mgr, 780 key=l2_key, 781 source_name=l2_source_name, 782 base_example_value=l1_value, 783 example_value=l2_value, 784 guard_manager_enum=l2_guard_manager_enum, 785 ) 786 return l1_mgr 787 788 def requires_key_order_guarding(self, source): 789 source_name = source.name() 790 if source_name == "": 791 return False 792 obj_id = id(self.get(source_name)) 793 return obj_id in self.key_order_guarded_dict_ids 794 795 def get_guard_manager_type(self, source, example_value): 796 guard_manager_enum = GuardManagerType.GUARD_MANAGER 797 if self.requires_key_order_guarding(source): 798 assert isinstance(example_value, dict) 799 # If keys method is not overriden, we can use PyDict_Next to get key 800 # orderings. 
Read more in guards.cpp 801 if type(example_value).keys is type({}).keys: 802 guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER 803 else: 804 guard_manager_enum = GuardManagerType.DICT_SUBCLASS_GUARD_MANAGER 805 return guard_manager_enum 806 807 def manager_guards_on_keys(self, mgr_enum): 808 return ( 809 mgr_enum == GuardManagerType.DICT_GUARD_MANAGER 810 or mgr_enum == GuardManagerType.DICT_SUBCLASS_GUARD_MANAGER 811 ) 812 813 def get_global_guard_manager(self): 814 assert self.guard_manager # to make mypy happy 815 return self.guard_manager.root.globals_dict_manager( 816 f_globals=self.scope["G"], 817 source="G", 818 example_value=self.scope["G"], 819 guard_manager_enum=GuardManagerType.GUARD_MANAGER, 820 ) 821 822 def get_guard_manager_from_source(self, source): 823 assert self.guard_manager # to make mypy happy 824 root_guard_manager = self.guard_manager.root 825 826 example_value = None 827 source_name = source.name() 828 829 if source_name != "" and source_name in self._cached_guard_managers: 830 return self._cached_guard_managers[source_name] 831 832 if source_name != "": 833 example_value = self.get(source_name) 834 835 guard_manager_enum = self.get_guard_manager_type(source, example_value) 836 837 # Get base manager related information 838 base_source_name = None 839 base_example_value = None 840 base_guard_manager = None 841 base_guard_manager_enum = GuardManagerType.GUARD_MANAGER 842 if isinstance(source, ChainedSource): 843 base_source_name = source.base.name() 844 base_example_value = self.get(base_source_name) 845 base_guard_manager = self.get_guard_manager_from_source(source.base) 846 base_guard_manager_enum = self.get_guard_manager_type( 847 source.base, base_example_value 848 ) 849 850 # Use istype instead of isinstance to check for exact type of source. 851 if istype(source, LocalSource): 852 # RootGuardManager accepts a dict but still its not a 853 # DictGuardManager because we will eventually move to 854 # fastlocals. 
855 out = root_guard_manager.dict_getitem_manager( 856 key=source.local_name, 857 source=source_name, 858 example_value=example_value, 859 guard_manager_enum=guard_manager_enum, 860 ) 861 elif istype(source, GlobalSource): 862 # Global manager accepts a dict but it is not a DictGuardManager 863 # because globals dict is big and we typically guard on a very 864 # selected items on globals. 865 out = self.get_global_guard_manager().dict_getitem_manager( 866 key=source.global_name, 867 source=source_name, 868 example_value=example_value, 869 guard_manager_enum=guard_manager_enum, 870 ) 871 elif istype(source, GlobalWeakRefSource): 872 out = self.get_global_guard_manager().global_weakref_manager( 873 global_name=source.global_name, 874 source=source_name, 875 example_value=example_value, 876 guard_manager_enum=guard_manager_enum, 877 ) 878 elif istype(source, GlobalStateSource): 879 # Don't do anything here. We guard on global state completely in 880 # C++. So just return the root mgr. 881 return root_guard_manager 882 elif istype(source, ShapeEnvSource): 883 return root_guard_manager 884 elif istype(source, TypeSource): 885 assert base_guard_manager # to make mypy happy 886 out = base_guard_manager.type_manager( 887 source=source_name, 888 example_value=example_value, 889 guard_manager_enum=guard_manager_enum, 890 ) 891 elif istype( 892 source, 893 ( 894 OptimizerSource, 895 NNModuleSource, 896 UnspecializedNNModuleSource, 897 UnspecializedBuiltinNNModuleSource, 898 FSDPNNModuleSource, 899 ), 900 ): 901 assert base_guard_manager # to make mypy happy 902 out = base_guard_manager 903 elif istype(source, GradSource): 904 assert base_guard_manager # to make mypy happy 905 out = base_guard_manager.grad_manager( 906 source=source_name, 907 example_value=example_value, 908 guard_manager_enum=guard_manager_enum, 909 ) 910 elif istype(source, (AttrSource, UnspecializedParamBufferSource)): 911 assert base_guard_manager # to make mypy happy 912 913 if ( 914 
isinstance(base_example_value, torch.nn.Module) 915 and get_custom_getattr(base_example_value) 916 is unpatched_nn_module_getattr 917 ): 918 out = self.getattr_on_nn_module( 919 source, 920 base_guard_manager, 921 base_example_value, 922 example_value, 923 base_source_name, 924 source_name, 925 guard_manager_enum, 926 ) 927 else: 928 out = base_guard_manager.getattr_manager( 929 attr=source.member, 930 source=source_name, 931 example_value=example_value, 932 guard_manager_enum=guard_manager_enum, 933 ) 934 elif istype(source, GetItemSource): 935 assert base_guard_manager # to make mypy happy 936 if isinstance(base_example_value, (dict, collections.OrderedDict)): 937 # TODO(anijain2305) - Consider isolating GetItemSource and 938 # DictGetItemSource (or maybe use ODictGetItemSource for 939 # dicts) so that GetItemSource is only for non dict objects. 940 if isinstance(base_guard_manager, DictGuardManager): 941 assert self.manager_guards_on_keys(base_guard_manager_enum) 942 out = getitem_on_dict_manager( 943 source, 944 base_guard_manager, 945 base_example_value, 946 example_value, 947 guard_manager_enum, 948 ) 949 else: 950 if isinstance(source.index, ConstDictKeySource): 951 raise RuntimeError( 952 "Expecting clean index here. 
Likely Dynamo forgot to mark" 953 " a dict as guard_on_key_order" 954 ) 955 out = base_guard_manager.dict_getitem_manager( 956 key=source.index, 957 source=source_name, 958 example_value=example_value, 959 guard_manager_enum=guard_manager_enum, 960 ) 961 elif isinstance(base_example_value, list) and not source.index_is_slice: 962 out = base_guard_manager.list_getitem_manager( 963 key=source.index, 964 source=source_name, 965 example_value=example_value, 966 guard_manager_enum=guard_manager_enum, 967 ) 968 elif isinstance(base_example_value, tuple) and not source.index_is_slice: 969 out = base_guard_manager.tuple_getitem_manager( 970 key=source.index, 971 source=source_name, 972 example_value=example_value, 973 guard_manager_enum=guard_manager_enum, 974 ) 975 else: 976 index = source.index 977 if source.index_is_slice: 978 index = source.unpack_slice() 979 out = base_guard_manager.getitem_manager( 980 key=index, 981 source=source_name, 982 example_value=example_value, 983 guard_manager_enum=guard_manager_enum, 984 ) 985 elif istype(source, ODictGetItemSource): 986 if isinstance(base_guard_manager, DictGuardManager): 987 assert self.manager_guards_on_keys(base_guard_manager_enum) 988 out = getitem_on_dict_manager( 989 source, 990 base_guard_manager, 991 base_example_value, 992 example_value, 993 guard_manager_enum, 994 ) 995 else: 996 assert base_guard_manager # to make mypy happy 997 out = base_guard_manager.dict_getitem_manager( 998 key=source.index, 999 source=source_name, 1000 example_value=example_value, 1001 guard_manager_enum=guard_manager_enum, 1002 ) 1003 elif istype(source, DefaultsSource): 1004 assert base_guard_manager # to make mypy happy 1005 assert callable(base_example_value) 1006 if not source.is_kw: 1007 out = base_guard_manager.func_defaults_manager( 1008 source=base_source_name, 1009 example_value=base_example_value.__defaults__, 1010 guard_manager_enum=GuardManagerType.GUARD_MANAGER, 1011 ).getitem_manager( 1012 key=source.idx_key, 1013 
source=source_name, 1014 example_value=example_value, 1015 guard_manager_enum=guard_manager_enum, 1016 ) 1017 else: 1018 # kwdefauts is a dict, so use a DictGuardManager 1019 kwdefaults = base_example_value.__kwdefaults__ 1020 assert base_source_name is not None 1021 kw_source = base_source_name + ".__kwdefaults__" 1022 1023 # kwdefaults is a dict. No need to guard on dict order. 1024 dict_mgr = base_guard_manager.func_kwdefaults_manager( 1025 source=kw_source, 1026 example_value=kwdefaults, 1027 guard_manager_enum=GuardManagerType.GUARD_MANAGER, 1028 ) 1029 assert not isinstance(dict_mgr, DictGuardManager) 1030 1031 out = dict_mgr.dict_getitem_manager( 1032 key=source.idx_key, 1033 source=source_name, 1034 example_value=example_value, 1035 guard_manager_enum=guard_manager_enum, 1036 ) 1037 elif istype(source, NumpyTensorSource): 1038 assert base_guard_manager # to make mypy happy 1039 out = base_guard_manager.lambda_manager( 1040 python_lambda=from_numpy, 1041 source=source_name, 1042 example_value=example_value, 1043 guard_manager_enum=guard_manager_enum, 1044 ) 1045 elif istype(source, SubclassAttrListSource): 1046 assert base_guard_manager # to make mypy happy 1047 out = base_guard_manager.lambda_manager( 1048 python_lambda=lambda x: x.__tensor_flatten__()[0], 1049 source=source_name, 1050 example_value=example_value, 1051 guard_manager_enum=guard_manager_enum, 1052 ) 1053 elif istype(source, FlattenScriptObjectSource): 1054 assert base_guard_manager # to make mypy happy 1055 out = base_guard_manager.lambda_manager( 1056 python_lambda=lambda x: x.__obj_flatten__(), 1057 source=source_name, 1058 example_value=example_value, 1059 guard_manager_enum=guard_manager_enum, 1060 ) 1061 elif istype(source, ScriptObjectQualifiedNameSource): 1062 assert base_guard_manager # to make mypy happy 1063 out = base_guard_manager.lambda_manager( 1064 python_lambda=lambda x: x._type().qualified_name(), 1065 source=source_name, 1066 example_value=example_value, 1067 
guard_manager_enum=guard_manager_enum, 1068 ) 1069 elif istype(source, AttrProxySource): 1070 assert base_guard_manager # to make mypy happy 1071 out = base_guard_manager.lambda_manager( 1072 python_lambda=lambda x: x.get_base(), 1073 source=source_name, 1074 example_value=example_value, 1075 guard_manager_enum=guard_manager_enum, 1076 ) 1077 elif istype(source, TupleIteratorGetItemSource): 1078 assert base_guard_manager # to make mypy happy 1079 out = base_guard_manager.tuple_iterator_getitem_manager( 1080 index=source.index, 1081 source=source_name, 1082 example_value=example_value, 1083 guard_manager_enum=guard_manager_enum, 1084 ) 1085 elif isinstance(source, ConstDictKeySource): 1086 if not isinstance(base_guard_manager, DictGuardManager): 1087 raise AssertionError( 1088 "ConstDictKeySource can only work on DictGuardManager" 1089 ) 1090 out = base_guard_manager.get_key_manager( 1091 index=source.index, 1092 source=source_name, 1093 example_value=example_value, 1094 guard_manager_enum=guard_manager_enum, 1095 ) 1096 elif isinstance(source, WeakRefCallSource): 1097 assert base_guard_manager # to make mypy happy 1098 out = base_guard_manager.weakref_call_manager( 1099 source=source_name, 1100 example_value=example_value, 1101 guard_manager_enum=guard_manager_enum, 1102 ) 1103 else: 1104 raise AssertionError( 1105 f"missing guard manager builder {source} - {source.name()}" 1106 ) 1107 1108 self._cached_guard_managers[source.name()] = out 1109 return out 1110 1111 def get_guard_manager(self, guard: Guard): 1112 return self.get_guard_manager_from_source(guard.originating_source) 1113 1114 def add_python_lambda_leaf_guard_to_root( 1115 self, 1116 code_parts, 1117 verbose_code_parts, 1118 closure_vars=CLOSURE_VARS, 1119 is_epilogue=True, 1120 ): 1121 # Adds a lambda leaf guard to the root guard manager. It wraps the 1122 # code_parts in a function object which is then passed on to the leaf 1123 # guard. 
1124 make_guard_fn_args = ", ".join(closure_vars.keys()) 1125 guard_body, pycode = build_guard_function(code_parts, make_guard_fn_args) 1126 out: Dict[str, Any] = {} 1127 globals_for_guard_fn = {"G": self.scope["G"]} 1128 exec(pycode, globals_for_guard_fn, out) 1129 guard_fn = out["___make_guard_fn"](*closure_vars.values()) 1130 assert self.guard_manager # to make mypy happy 1131 if is_epilogue: 1132 # Epilogue guards are run after all the other guards have finished. 1133 # If epilogue guards contain a getattr or getitem access, one of the 1134 # other guards would fail preventing the epilogue guards to run. 1135 self.guard_manager.root.add_epilogue_lambda_guard( 1136 guard_fn, verbose_code_parts 1137 ) 1138 else: 1139 self.guard_manager.root.add_lambda_guard(guard_fn, verbose_code_parts) 1140 1141 # Warning: use this with care! This lets you access what the current 1142 # value of the value you are guarding on is. You probably don't want 1143 # to actually durably save this value though (because it's specific 1144 # to this frame!) Instead, you should be reading out some property 1145 # (like its type) which is what you permanently install into the 1146 # guard code. 1147 def get(self, name: str) -> Any: 1148 return eval(name, self.scope, CLOSURE_VARS) 1149 1150 # Registers the usage of the source name referenced by the 1151 # string (or stored in the Guard) as being guarded upon. It's important 1152 # to call this before generating some code that makes use of 'guard', 1153 # because without this call, we won't actually bind the variable 1154 # you reference in the actual guard closure (oops!) 
def arg_ref(self, guard: Union[str, Guard]) -> str:
    """Register the base variable referenced by ``guard`` and return its name.

    ``guard`` may be a source-name string or a Guard (whose ``.name`` is
    used). The base variable (name with function calls and getattr/getitem
    stripped) is appended to ``self.argnames`` if not already present and
    it starts with a word character; purely numeric bases are still
    appended but logged as invalid.
    """
    name: str = guard if isinstance(guard, str) else guard.name
    base = strip_getattr_getitem(strip_function_call(name))
    if base not in self.argnames:
        if re.match(r"[a-zA-Z0-9_]+", base):
            if re.match(r"^\d+$", base):
                log.warning("invalid var name: %s", guard)
            self.argnames.append(base)

    return name

def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn):
    """Create and install a ``guard_fn`` guard on ``<guarded obj>.<attr_name>``.

    The new guard reuses the stack and user stack of ``guard``.
    """
    attr_source = AttrSource(guard.originating_source, attr_name)
    # Copy the stack info
    new_guard = Guard(
        attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack
    )
    new_guard.create(self)

# Note: the order of the guards in this file matters since we sort guards on the same object by lineno
def HASATTR(self, guard: Guard):
    """Guard on whether the attribute named by an AttrSource exists on its base.

    If the attribute exists, a getattr accessor is installed (it acts as
    the hasattr check); otherwise a "no hasattr" leaf guard is added.
    """
    source = guard.originating_source
    if isinstance(source, NNModuleSource):
        source = source.base
    assert isinstance(source, AttrSource), f"invalid source {guard.name}"
    base_source = source.base
    base = base_source.name()
    attr = source.member

    ref = self.arg_ref(base)
    val = hasattr(self.get(base), attr)
    code = f"hasattr({ref}, {attr!r})" if val else f"not hasattr({ref}, {attr!r})"
    self._set_guard_export_info(
        guard, [code], provided_guarded_object=self.get(base)
    )

    if config.enable_cpp_guard_manager:
        base_manager = self.get_guard_manager_from_source(base_source)
        if val:
            # Just install a getattr manager. GetAttrGuardAccessor itself
            # acts as hasattr guard.
            example_value = self.get(source.name())
            base_example_value = self.get(base)
            guard_manager_enum = self.get_guard_manager_type(source, example_value)

            # if the base value is nn.Module, check if we can speed up the
            # guard by going through __dict__ attrs.
            if (
                isinstance(base_example_value, torch.nn.Module)
                and get_custom_getattr(base_example_value)
                is unpatched_nn_module_getattr
            ):
                return self.getattr_on_nn_module(
                    source,
                    base_manager,
                    base_example_value,
                    example_value,
                    base,
                    source.name(),
                    guard_manager_enum,
                )
            else:
                base_manager.getattr_manager(
                    attr=attr,
                    source=guard.name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
        else:
            base_manager.add_no_hasattr_guard(
                attr, get_verbose_code_parts(code, guard)
            )
    else:
        self._produce_guard_code(guard, [code])

def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None:
    """Guard that ``attr`` is not a key of the nn.Module's ``__dict__``.

    This builder only installs a guard-manager (C++) guard; it has no
    Python guard-code path.
    """
    assert attr is not None
    ref = self.arg_ref(guard)
    val = self.get(guard.name)
    assert isinstance(val, torch.nn.Module)

    base_manager = self.get_guard_manager(guard)

    mod_dict_source = f"{guard.name}.__dict__"
    mod_generic_dict_manager = base_manager.get_generic_dict_manager(
        source=mod_dict_source,
        example_value=val.__dict__,
        guard_manager_enum=GuardManagerType.GUARD_MANAGER,
    )

    code = f"not ___dict_contains({attr!r}, {ref}.__dict__)"
    mod_generic_dict_manager.add_dict_contains_guard(
        False, attr, get_verbose_code_parts(code, guard)
    )

def TYPE_MATCH(self, guard: Guard) -> None:
    """Guard that the value's type is the same object (by id) as now."""
    # ___check_type_id is same as `id(type(x)) == y`
    t = type(self.get(guard.name))
    obj_id = self.id_ref(t)
    code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
    self._set_guard_export_info(guard, [code])

    if config.enable_cpp_guard_manager:
        self.get_guard_manager(guard).add_type_match_guard(
            obj_id, get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, [code])

def DICT_VERSION(self, guard: Guard):
    """Guard that the dict's version tag equals its current value."""
    # ___check_dict_version is same as `dict_version(x) == y`
    ref = self.arg_ref(guard)
    val = self.get(guard.name)
    version = dict_version(val)
    code = f"___dict_version({ref}) == {version}"
    self._set_guard_export_info(guard, [code])

    if config.enable_cpp_guard_manager:
        # TODO(anijain2305) - Delete this when DictGuardManager uses tags
        # for dicts.
        self.get_guard_manager(guard).add_dict_version_guard(
            val, get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, [code])

def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool):
    """Guard that ``key`` is in the dict (or is not, when ``invert`` is True)."""
    dict_ref = self.arg_ref(guard)

    maybe_not = "not " if invert else ""
    code = f"{maybe_not}___dict_contains({key!r}, {dict_ref})"
    self._set_guard_export_info(guard, [code])

    if config.enable_cpp_guard_manager:
        self.get_guard_manager(guard).add_dict_contains_guard(
            not invert, key, get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, [code])

def ID_MATCH(self, guard: Guard):
    """Guard on object identity (``id(x) == <current id>``).

    For a TypeSource this is rewritten as a TYPE_MATCH on the base. For
    nn.Module values coming from a LocalSource, a weak id is recorded in
    ``self.id_matched_objs`` (when ``lookup_weakrefs`` finds one).
    """
    # ___check_obj_id is same as `id(x) == y`
    if isinstance(guard.originating_source, TypeSource):
        # optional optimization to produce cleaner/faster guard code
        return self.TYPE_MATCH(
            Guard(guard.originating_source.base, GuardBuilder.TYPE_MATCH)  # type: ignore[arg-type]
        )

    ref = self.arg_ref(guard)
    val = self.get(guard.name)
    id_val = self.id_ref(val)
    code = f"___check_obj_id({ref}, {id_val})"
    self._set_guard_export_info(guard, [code])

    if config.enable_cpp_guard_manager:
        self.get_guard_manager(guard).add_id_match_guard(
            id_val, get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, [code])

    # Keep track of ID_MATCH'd objects. This will be used to modify the
    # cache size logic
    if isinstance(guard.originating_source, LocalSource):
        # TODO(anijain2305) - This is currently restricted to nn.Module objects
        # because many other ID_MATCH'd objects fail - like DeviceMesh.
        # Increase the scope of ID_MATCH'd objects.
        if isinstance(val, torch.nn.Module):
            local_name = guard.originating_source.local_name
            weak_id = self.lookup_weakrefs(val)
            if weak_id is not None:
                self.id_matched_objs[local_name] = weak_id

def NOT_NONE_MATCH(self, guard: Guard, value=None):
    """Guard that the (tensor) value is not None.

    ``value`` is accepted for interface compatibility but is not used.
    """
    ref = self.arg_ref(guard)
    val = self.get(guard.name)
    assert isinstance(val, torch.Tensor)
    code = f"{ref} is not None"
    self._set_guard_export_info(guard, [code])

    if config.enable_cpp_guard_manager:
        self.get_guard_manager(guard).add_not_none_guard(
            get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, [code])

def NAME_MATCH(self, guard: Guard):
    """EQUALS_MATCH on the value's ``__name__`` attribute."""
    self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH)

def DATA_PTR_MATCH(self, guard: Guard):
    """Guard that ``x.data_ptr()`` equals its current value."""
    # Add a type check. C++ guard has the type check internally, so only
    # enable it for Python guards.
    if not config.enable_cpp_guard_manager:
        self.TYPE_MATCH(guard)

    obj = self.get(guard.name)
    code = f"{self.arg_ref(guard)}.data_ptr() == {obj.data_ptr()}"
    self._set_guard_export_info(guard, [code])

    if config.enable_cpp_guard_manager:
        self.get_guard_manager(guard).add_data_ptr_guard(
            obj, get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, [code])

def DUAL_LEVEL(self, guard: Guard):
    """Guard that the forward-AD current level equals its compile-time value."""
    # Invalidate dual level if current dual level is different than the one
    # in the fx graph
    dual_level = torch.autograd.forward_ad._current_level
    code = [f"torch.autograd.forward_ad._current_level == {dual_level}"]
    # `code` is already a list of code parts; do not wrap it again.
    self._set_guard_export_info(guard, code)
    if config.enable_cpp_guard_manager:
        # TODO(anijain2305) - Consider moving this guard to C++
        forward_ad = torch.autograd.forward_ad

        def fn(x):
            return forward_ad._current_level == dual_level

        assert self.guard_manager  # to make mypy happy
        self.guard_manager.root.add_lambda_guard(
            fn, get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, code)

def FUNCTORCH_STACK_MATCH(self, guard: Guard):
    """Guard that the functorch interpreter stack state is unchanged."""
    # Invalidate functorch code if current level is different than
    # the one when FX graph was generated
    cis = torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters()
    states = [ci.get_state() for ci in cis]
    code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"]
    self._set_guard_export_info(guard, code)

    if config.enable_cpp_guard_manager:
        # TODO(anijain2305) - Consider moving this guard to C++
        compare_fn = torch._functorch.pyfunctorch.compare_functorch_state

        def fn(x):
            return compare_fn(states)

        assert self.guard_manager  # to make mypy happy
        self.guard_manager.root.add_lambda_guard(
            fn, get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, code)

def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard):
    """Guard that a tensor subclass's flatten metadata is unchanged.

    Compares element ``[1]`` of ``__tensor_flatten__()`` against a deep copy
    taken now, via the subclass's ``__metadata_guard__`` when it defines one,
    otherwise via ``==``.
    """
    value = self.get(guard.name)
    original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1])
    if hasattr(value, "__metadata_guard__"):
        verify_guard_fn_signature(value)

        def metadata_checker(x):
            return value.__metadata_guard__(
                original_metadata, x.__tensor_flatten__()[1]
            )

    else:

        def metadata_checker(x):
            return x.__tensor_flatten__()[1] == original_metadata

    global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}"
    if config.enable_cpp_guard_manager:
        self.get_guard_manager(guard).add_lambda_guard(
            metadata_checker, get_verbose_code_parts(global_name, guard)
        )
    else:
        # The Python path emits the checker's name into generated guard
        # source, so it must be a valid identifier. NOTE(review): the
        # compile id's string form may contain characters such as "/".
        py_global_name = re.sub(r"\W", "_", global_name)
        global_scope = self.get("G")
        global_scope[py_global_name] = metadata_checker
        # Reference the guarded variable (and register it via arg_ref)
        # rather than interpolating the repr of its current value.
        code = [f"{py_global_name}({self.arg_ref(guard)})"]
        self._produce_guard_code(guard, code)

def EQUALS_MATCH(self, guard: Guard):
    """Guard on ``x == <current value>`` for constant-like types.

    Only types in ``ok_types`` (plus dicts whose keys/values are all such
    types) are accepted. Float NaN and complex NaN are guarded with isnan
    checks, because NaN never compares equal to itself.
    """
    ref = self.arg_ref(guard)
    val = self.get(guard.name)
    if np:
        np_types: Tuple[Type[Any], ...] = (
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.uint8,
            np.uint16,
            np.uint32,
            np.uint64,
            np.float16,
            np.float32,
            np.float64,
        )
    else:
        np_types = ()

    ok_mutable_types = (list, set)

    ok_types = tuple(
        common_constant_types
        | {
            type,
            tuple,
            frozenset,
            slice,
            range,
            torch.Size,
            *np_types,
            *ok_mutable_types,
        }
    )

    if torch.distributed.is_available():
        from torch.distributed.device_mesh import DeviceMesh
        from torch.distributed.tensor.placement_types import (
            Partial,
            Replicate,
            Shard,
        )

        ok_types = ok_types + (
            Shard,
            Replicate,
            Partial,
            DeviceMesh,
        )

    if istype(val, dict):
        assert all(
            istype(x, ok_types) for x in itertools.chain(val.keys(), val.values())
        )
    else:
        assert istype(
            val,
            ok_types,
        ), f"Unexpected type {type(val)}, not in {ok_types}"

    # Special case for nan because float("nan") == float("nan") evaluates to False
    if istype(val, float) and math.isnan(val):
        self.TYPE_MATCH(guard)
        code = []
        code.append(f"__math_isnan({ref})")
        self._set_guard_export_info(guard, code)

        if config.enable_cpp_guard_manager:
            self.get_guard_manager(guard).add_lambda_guard(
                CLOSURE_VARS["__math_isnan"], get_verbose_code_parts(code, guard)
            )
        else:
            self._produce_guard_code(guard, code)
        return

    # Complex NaN is detected and guarded with numpy's isnan. numpy is an
    # optional import (np may be None), so only take this path when it exists.
    if istype(val, complex) and np is not None and np.isnan(val):
        self.TYPE_MATCH(guard)
        code = []
        code.append(f"__numpy_isnan({ref})")
        self._set_guard_export_info(guard, code)

        if config.enable_cpp_guard_manager:
            self.get_guard_manager(guard).add_lambda_guard(
                CLOSURE_VARS["__numpy_isnan"], get_verbose_code_parts(code, guard)
            )
        else:
            self._produce_guard_code(guard, code)
        return

    if config.enable_cpp_guard_manager:
        # Construct a debug string to put into the c++ equals match guard.
        code = [f"{ref} == {val!r}"]
        if istype(val, ok_mutable_types):
            # C++ guards perform a pointer equality check to speedup guards, but the assumption is that the object
            # is mutable. For a few corner cases like sets and lists, we make a deepcopy to purposefully fail the
            # pointer equality check.
            val = deepcopy(val)
        self.get_guard_manager(guard).add_equals_match_guard(
            val, get_verbose_code_parts(code, guard)
        )
        self._set_guard_export_info(guard, code)
        return

    code = []

    # If matching equality against list/tuple, we must also check that
    # the internal types match. (TODO: what about nested lists?)
    if istype(val, (list, tuple)):
        # NB: SEQUENCE_LENGTH takes care of the outer __check_type_id test
        self.SEQUENCE_LENGTH(guard)

        for idx, elem in enumerate(val):
            code.append(
                f"___check_type_id({ref}[{idx}], {self.id_ref(type(elem))})"
            )
    else:
        # Add type check to prevent equality check between tensor and non-tensor.
        self.TYPE_MATCH(guard)

    if istype(val, torch.Size):
        val = tuple(val)

    # Code object can not be compared against their string representation
    # I.e `eval(f"{compile('2+2','','exec')!r}")` raises SyntaxError
    assert not istype(val, types.CodeType)

    # TODO: It feels like it would be better to just implement our own
    # equality test in C that handles all of the necessary type checking
    # and NaN tests
    code.append(f"{ref} == {val!r}")
    self._produce_guard_code(guard, code)
    self._set_guard_export_info(guard, code)

def CONSTANT_MATCH(self, guard: Guard):
    """ID_MATCH for bool/None/code objects, EQUALS_MATCH for everything else."""
    val = self.get(guard.name)
    if istype(val, (bool, type(None), types.CodeType)):
        self.ID_MATCH(guard)
    else:
        self.EQUALS_MATCH(guard)

def NN_MODULE(self, guard: Guard):
    """ID_MATCH the module and CONSTANT_MATCH its ``training`` flag."""
    self.ID_MATCH(guard)
    val = self.get(guard.name)
    if hasattr(val, "training"):
        assert istype(val.training, bool)
        self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH)
    else:
        exc.unimplemented(f"Guard setup for uninitialized class {type(val)}")

def FUNCTION_MATCH(self, guard: Guard):
    """things like torch.add and user defined functions"""
    return self.ID_MATCH(guard)

def CLOSURE_MATCH(self, guard: Guard):
    """matches a closure by __code__ id."""
    val = self.get(guard.name)
    # Strictly only want user-defined functions
    if type(val) == types.FunctionType and hasattr(val, "__code__"):
        self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR)
        self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH)
    else:
        self.FUNCTION_MATCH(guard)

def BUILTIN_MATCH(self, guard: Guard):
    """Builtins are matched by identity (via FUNCTION_MATCH)."""
    return self.FUNCTION_MATCH(guard)

def PYMODULE_MATCH(self, guard: Guard):
    """Python modules are matched by identity (via FUNCTION_MATCH)."""
    return self.FUNCTION_MATCH(guard)

def SEQUENCE_LENGTH(self, guard):
    """Guard on ``len(x)``; dicts get a dedicated dict-length guard in C++."""
    # This guard is used to check length of PySequence objects like list,
    # tuple, collections.deque etc
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    if not (config.enable_cpp_guard_manager and isinstance(value, dict)):
        # C++ DICT_LENGTH checks for type
        self.TYPE_MATCH(guard)

    code = []
    if len(value) == 0:
        code.append(f"not {ref}")
    else:
        code.append(f"len({ref}) == {len(value)}")

    self._set_guard_export_info(guard, code)
    if config.enable_cpp_guard_manager:
        if isinstance(value, dict):
            self.get_guard_manager(guard).add_dict_length_check_guard(
                len(value), get_verbose_code_parts(code, guard)
            )
        else:
            self.get_guard_manager(guard).add_length_check_guard(
                len(value), get_verbose_code_parts(code, guard)
            )
    else:
        self._produce_guard_code(guard, code)

def TUPLE_ITERATOR_LEN(self, guard):
    """Guard on the remaining length of a tuple iterator (and its type)."""
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    if not config.enable_cpp_guard_manager:
        # C++ guard already checks the type
        self.TYPE_MATCH(guard)

    code = []
    code.append(f"___tuple_iterator_len({ref}) == {tuple_iterator_len(value)}")
    self._set_guard_export_info(guard, code)

    if config.enable_cpp_guard_manager:
        t = type(value)
        obj_id = self.id_ref(t)

        self.get_guard_manager(guard).add_tuple_iterator_length_guard(
            tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, code)

# TODO(voz): Deduplicate w/ AOTAutograd dupe input guards
def DUPLICATE_INPUT(self, guard, source_b):
    """Guard that ``guard``'s value and ``source_b``'s value are the same object.

    Skipped when either side comes from an optimizer source. With the C++
    guard manager, each unordered pair is installed at most once.
    """
    ref_a = self.arg_ref(guard)
    ref_b = self.arg_ref(source_b.name())

    if is_from_optimizer_source(
        guard.originating_source
    ) or is_from_optimizer_source(source_b):
        return

    code = [f"{ref_b} is {ref_a}"]
    self._set_guard_export_info(guard, code)

    if config.enable_cpp_guard_manager:
        # Check that the guard has not been inserted already
        key = (ref_a, ref_b)
        if key in self._cached_duplicate_input_guards:
            return
        # Record both orientations so (b, a) is also treated as installed.
        self._cached_duplicate_input_guards.add(key)
        self._cached_duplicate_input_guards.add((ref_b, ref_a))

        install_object_aliasing_guard(
            self.get_guard_manager(guard),
            self.get_guard_manager_from_source(source_b),
            get_verbose_code_parts(code, guard),
        )
    else:
        self._produce_guard_code(guard, code)

def DICT_KEYS(self, guard):
    """Guard on a dict's type and keys (key order only when required)."""
    # Guard on the keys and their order
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    self.TYPE_MATCH(guard)
    code = []
    any_key_is_id = any(key_is_id(k) for k in value.keys())
    const_keys_repr = dict_keys_repr(
        key_to_id(value),
        local=is_from_local_source(guard.originating_source),
    )
    if any_key_is_id:
        code.append(f"___key_to_id({ref}) == {const_keys_repr}")
    else:
        code.append(f"list({ref}.keys()) == {const_keys_repr}")

    self._set_guard_export_info(guard, code)
    if config.enable_cpp_guard_manager:
        if self.requires_key_order_guarding(guard.originating_source):
            self.guard_on_dict_keys_and_order(value, guard)
        else:
            self.guard_on_dict_keys_and_ignore_order(value, guard)
    else:
        self._produce_guard_code(guard, code)

def WEAKREF_ALIVE(self, guard):
    """Guard that the (dereferenced) weakref value is not None."""
    code = [f"{self.arg_ref(guard)} is not None"]

    self._set_guard_export_info(guard, code)
    if config.enable_cpp_guard_manager:
        self.get_guard_manager(guard).add_not_none_guard(
            get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, code)

def DICT_CONST_KEYS(self, guard):
    """Constant keys match"""
    ref = self.arg_ref(guard)
    value = self.get(guard.name)

    if not config.enable_cpp_guard_manager:
        # DictGuardManager supports TYPE_MATCH internally
        self.TYPE_MATCH(guard)

    code = []
    code.append(f"list({ref}.keys()) == {list(value.keys())!r}")
    self._set_guard_export_info(guard, code)

    if config.enable_cpp_guard_manager:
        if self.requires_key_order_guarding(guard.originating_source):
            self.guard_on_dict_keys_and_order(value, guard)
        else:
            self.guard_on_dict_keys_and_ignore_order(value, guard)
    else:
        self._produce_guard_code(guard, code)

def EMPTY_NN_MODULE_HOOKS_DICT(self, guard):
    """Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards"""
    if config.skip_nnmodule_hook_guards:
        # This is unsafe if you add/remove a hook on nn module variable
        return
    self.SEQUENCE_LENGTH(guard)

def OBJECT_MUTATION(self, guard: Guard):
    """Register the object with mutation_guard so mutations are observed."""
    mutation_guard.watch(self.get(guard.name), self.check_fn_manager)

def GRAD_MODE(self, guard: Guard):
    pass  # we always guard on this via GlobalStateGuard()

def DETERMINISTIC_ALGORITHMS(self, guard: Guard):
    pass  # we always guard on this via GlobalStateGuard()

def TORCH_FUNCTION_STATE(self, guard: Guard):
    pass  # we always guard on this via GlobalStateGuard()

def FSDP_TRAINING_STATE(self, guard: Guard):
    pass  # we always guard on this via GlobalStateGuard()

def DEFAULT_DEVICE(self, guard: Guard):
    """Guard on CURRENT_DEVICE per torch.utils._device"""
    assert guard.source is GuardSource.GLOBAL
    import torch.utils._device as m

    code = [f"utils_device.CURRENT_DEVICE == {m.CURRENT_DEVICE!r}"]
    self._set_guard_export_info(guard, code)

    if config.enable_cpp_guard_manager:
        self.get_guard_manager(guard).add_default_device_guard(
            get_verbose_code_parts(code, guard)
        )
    else:
        self._produce_guard_code(guard, code)

def SHAPE_ENV(self, guard: Guard):
    """Install the symbolic-shape guards produced by the output graph's ShapeEnv.

    When exporting, export constraints are turned into an EqualityConstraint
    and the ShapeEnv is not frozen. Otherwise the ShapeEnv is frozen after
    guards are produced.
    """
    # Let's handle ShapeEnv guards.  To do this, we will resolve
    # shape variables to sources from tracked_fakes.  This must happen after
    # tensor checks.
    assert guard.name == ""
    output_graph = self.check_fn_manager.output_graph
    # NB: self.output_graph can be None in the debug_nops tests
    fs = output_graph.tracked_fakes
    input_contexts = [a.symbolic_context for a in fs]

    def get_sources(t_id, dim):
        # Looks up base sources mapped to a tensor id and uses them to create
        # sources for the corresponding tensor dimension.
        return [
            TensorPropertySource(source, TensorProperty.SIZE, dim)
            for source in output_graph.tracked_fakes_id_to_source[t_id]
        ]

    if output_graph.export_constraints:
        names: Dict[str, Tuple[int, int]] = {}
        source_pairs: List[Tuple[Source, Source]] = []
        derived_equalities: List[  # type: ignore[type-arg]
            Tuple[Source, Union[Source, Symbol], Callable]
        ] = []
        phantom_symbols: Dict[str, Symbol] = {}
        for constraint in output_graph.export_constraints:
            if constraint.t_id in output_graph.tracked_fakes_id_to_source:
                torch.export.dynamic_shapes._process_equalities(
                    constraint,
                    get_sources,
                    output_graph.shape_env,
                    names,
                    source_pairs,
                    derived_equalities,
                    phantom_symbols,
                )
            else:
                log.warning("Untracked tensor used in export constraints")
        equalities_inputs = EqualityConstraint(
            source_pairs=source_pairs,
            derived_equalities=derived_equalities,
            phantom_symbols=list(phantom_symbols.values()),
            warn_only=False,
        )
    else:
        equalities_inputs = None
    guards = output_graph.shape_env.produce_guards(
        [a.fake for a in fs],
        [a.source for a in fs],
        input_contexts=input_contexts,
        equalities_inputs=equalities_inputs,
        source_ref=self.source_ref,
        # Export keeps static.
        ignore_static=(not self.check_fn_manager.output_graph.export),
    )
    # When exporting, we may work with the shape constraints some more in
    # postprocessing, so don't freeze yet
    if not self.check_fn_manager.output_graph.export:
        output_graph.shape_env.freeze()

    for shape_guard in guards:
        self._set_guard_export_info(guard, [shape_guard])

    if config.enable_cpp_guard_manager:
        # Install all the symbolic guards in one lambda guard. These are run
        # at the very end of the RootGuardManager via epilogue guards.
        # TODO(anijain2305,williamwen42) - Consider moving this to C++.
        code_parts = guards
        self.add_python_lambda_leaf_guard_to_root(
            code_parts,
            get_verbose_code_parts(code_parts, guard),
            closure_vars={**SYMPY_INTERP, **CLOSURE_VARS},
        )
    else:
        for shape_guard in guards:
            self._produce_guard_code(guard, [shape_guard], shape_env=True)

def TENSOR_MATCH(self, guard: Guard, value=None):
    # For FSDP modules, we can skip guards on nn module tensors because FSDP
    # eager assumes that the params are unchanged once the model is wrapped.
    if guard.is_fsdp_module():
        return

    # For tensors that are part of the Dynamo extracted Fx graph module, an
    # ID_MATCH suffices. Once we turn on inline_inbuilt_nn_modules, these
    # will be lifted as inputs and have a TENSOR_MATCH guard.
    # For numpy tensors, always use TENSOR_MATCH because __from_numpy leads
    # to a new tensor everytime and therefore id differs.
    if (
        guard.is_specialized_nn_module()
        and not isinstance(guard.originating_source, NumpyTensorSource)
    ) or match_on_id_for_tensor(guard):
        self.ID_MATCH(guard)
    else:
        if isinstance(value, TensorWeakRef):
            value = value()

        value = value if value is not None else self.get(guard.name)
        assert isinstance(value, torch.Tensor)

        tensor_name = self.arg_ref(guard)
        # [Note - On Export Tensor Guards]
        #
        # In eager mode, tensor guards are evaluated through C++, in guards.cpp
        # see [Note - On Eager Tensor Guards] for more info.
        #
        # In export mode, we instead maintain parallel logic between C++ and python
        # here, with an exception of checking the dispatch key - with the idea that a dispatch key
        # is an entirely runtime notion that would make no sense to keep in an exported graph.
        #
        # Now, this idea is okay, but to paraphrase @ezyang, this mental model is sufficient for now, although
        # not entirely true.
        # For example, suppose one of the input tensors had the negative dispatch key.
        # You should end up with a graph that is specialized for tensors that have a negative dispatch key.
        # If you allow a Tensor that does NOT have this bit set, you will accidentally run it "as if" it were negated.
        # Now, negative key only shows up for complex numbers, and most likely, the exported to target doesn't
        # support this feature at all, but the point stands that :some: tensor state only shows up on dispatch key.
        # TODO(voz): Either populate a dispatch_key check into the guards, or error on users passing in an unsupported
        # subset of keys during export.
        #
        # The list of tensor fields and calls we care about can be found in `terms` below.
        # TODO(voz): We are missing storage offset in all our tensor guards?
        code: List[str] = []
        if self.check_fn_manager.output_graph.export:
            self.TYPE_MATCH(guard)
            terms = [
                "dtype",
                "device",
                "requires_grad",
                "ndimension()",
            ]

            for term in terms:
                real_value = self.get(tensor_name + "." + term)
                if istype(real_value, (torch.device, torch.dtype)):
                    # copy pasted from EQUALS_MATCH
                    code.append(f"str({tensor_name}.{term}) == {str(real_value)!r}")
                else:
                    code.append(f"{tensor_name}.{term} == {real_value}")
        else:
            self.tensor_check_examples.append(value)
            self.tensor_check_names.append(tensor_name)
            self.tensor_check_guards.append(guard)

            if config.enable_cpp_guard_manager:
                guard_manager = self.get_guard_manager(guard)
                # Keep track of all the tensor guard managers to insert
                # NoAliasing check at the end.
                self.tensor_check_guard_managers.append(guard_manager)

                output_graph = self.check_fn_manager.output_graph
                metadata = output_graph.input_source_to_sizes_strides[
                    guard.originating_source
                ]
                size = convert_to_concrete_values(metadata["size"])
                stride = convert_to_concrete_values(metadata["stride"])

                verbose_code_parts = get_verbose_code_parts(
                    get_tensor_guard_code_part(value, tensor_name, size, stride),
                    guard,
                )
                guard_manager.add_tensor_match_guard(
                    value,
                    size,
                    stride,
                    tensor_name,
                    verbose_code_parts,
                )

        # A frame is valid for reuse with dynamic dimensions if the new
        # (user-requested) dynamic dimensions are a subset of the old
        # (already compiled) dynamic dimensions.
        #
        # It's a little non-obvious why you'd want this: in particular,
        # if an already compiled frame matches all of the guards, why
        # not just use it, why force a recompile?
        #
        # We force it for two reasons:
        #
        # - The user *required* us to compile with a new dynamic dimension,
        #   we should not ignore that and serve up the old, specialized
        #   frame.  Listen to the user!
        #
        # - In fact, we are obligated to *raise an error* if we fail to
        #   make the requested dimension dynamic.  If we don't
        #   recompile, we can't tell if that dimension can actually be
        #   made dynamic.
        #
        # If the new dynamic dims are a subset of the old, we already know
        # we can make them dynamic (since we made them dynamic in old).
        # This is slightly unsound, because maybe your input size is
        # [s0, s0, s1] and so you can do it dynamic if you say dynamic
        # dims {0, 1, 2} but you can't if you only do {0, 2} (because now
        # the second s0 is specialized).  But we're not entirely sure if
        # this is a good idea anyway lol... (if you want to try removing
        # this logic, be my guest!  -- ezyang 2024)
        #
        assert guard.source is not None
        static, reason = tensor_always_has_static_shape(
            value, is_tensor=True, tensor_source=guard.originating_source
        )

        if not static:
            if hasattr(value, "_dynamo_dynamic_indices"):
                dynamic_indices = value._dynamo_dynamic_indices
                code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)"  # noqa: B950
                code.append(code_part)
                if config.enable_cpp_guard_manager:
                    self.get_guard_manager(guard).add_dynamic_indices_guard(
                        dynamic_indices, get_verbose_code_parts(code_part, guard)
                    )
            # In the case of us not having any dynamic dimension indices, we compiled the frame with no chance of
2003 else: 2004 code_part = ( 2005 f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False" 2006 ) 2007 code.append(code_part) 2008 if config.enable_cpp_guard_manager: 2009 self.get_guard_manager(guard).add_no_hasattr_guard( 2010 "_dynamo_dynamic_indices", 2011 get_verbose_code_parts(code_part, guard), 2012 ) 2013 if len(code) > 0: 2014 self._set_guard_export_info(guard, code) 2015 if not config.enable_cpp_guard_manager: 2016 self._produce_guard_code(guard, code) 2017 2018 # A util that appends guarded code 2019 def _produce_guard_code(self, guard, code_list, shape_env=False): 2020 assert not config.enable_cpp_guard_manager 2021 if shape_env: 2022 self.shape_env_code.append(GuardCodeList(code_list, guard)) 2023 else: 2024 self.code.append(GuardCodeList(code_list, guard)) 2025 2026 # A util that in the case of export, adds data onto guards 2027 def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None): 2028 # WARNING: It is important that cur_frame/caller do NOT stay in 2029 # the current frame, because they will keep things live longer 2030 # than they should. See TestMisc.test_release_module_memory 2031 cur_frame = currentframe() 2032 assert cur_frame is not None 2033 caller = cur_frame.f_back 2034 del cur_frame 2035 assert caller is not None 2036 func_name = getframeinfo(caller)[2] 2037 del caller 2038 # We use func_name for export, so might as well get a nice defensive check out of it 2039 assert func_name in dir( 2040 self.__class__ 2041 ), f"_produce_guard_code must be called from inside GuardedCode. 
Called from {func_name}" 2042 2043 # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD) 2044 if provided_guarded_object is None: 2045 name_valid = guard.name is not None and guard.name != "" 2046 2047 guarded_object = self.get(guard.name) if name_valid else None 2048 else: 2049 guarded_object = provided_guarded_object 2050 2051 guarded_object_type = ( 2052 weakref.ref(type(guarded_object)) if guarded_object is not None else None 2053 ) 2054 obj_ref = None 2055 # Not necessary to have weakref for Enum type, but there is a bug that 2056 # makes hasattr(guarded_object.__class__, "__weakref__") return True. 2057 if hasattr(guarded_object.__class__, "__weakref__") and not isinstance( 2058 guarded_object, enum.Enum 2059 ): 2060 obj_ref = weakref.ref(guarded_object) 2061 2062 guard.set_export_info( 2063 func_name, 2064 guarded_object_type, 2065 code_list, 2066 obj_ref, 2067 ) 2068 2069 2070# Common Sub-Expression Elimination for Python expressions. 2071# 2072# There are 2 steps to this pass: 2073# 1. Count the frequency of each sub-expression (i.e. inner 2074# node in the AST tree) 2075# 2076# 2. Replace those that occur more than once by a fresh variable 'v'. 2077# 'v' will be defined in the 'preface' list (output argument to 2078# 'NodeTransformer') 2079# 2080# NB: the use of 'ast.unparse' while visiting the nodes makes this pass 2081# quadratic on the depth of the tree. 2082# 2083# NB: this pass creates a new variable for each AST node that is repeated 2084# more than 'USE_THRESHOLD'. e.g. if 'a.b.c.d' is used 10 times, 'a.b.c' 2085# and 'a.b' are also used 10 times. So, there will be a new variable for 2086# each of them. 2087class PyExprCSEPass: 2088 # Maximum number of times a given expression can be used without being 2089 # replaced by a fresh variable. 2090 USE_THRESHOLD = 1 2091 2092 # Ad-Hoc: AST nodes this pass focuses on. 
2093 ALLOWED_NODE_TYPES = (ast.Attribute, ast.Call, ast.Subscript) 2094 2095 @dataclasses.dataclass 2096 class Config: 2097 expr_count: Dict[str, int] 2098 expr_to_name: Dict[str, str] 2099 2100 class ExprCounter(ast.NodeVisitor): 2101 def __init__(self, config: PyExprCSEPass.Config) -> None: 2102 self._config = config 2103 2104 def visit(self, node: ast.AST) -> Any: 2105 if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): 2106 self._config.expr_count[_ast_unparse(node)] += 1 2107 super().visit(node) 2108 2109 class Replacer(ast.NodeTransformer): 2110 def __init__( 2111 self, 2112 config: PyExprCSEPass.Config, 2113 gen_name: Callable[[], str], 2114 ) -> None: 2115 super().__init__() 2116 self._config = config 2117 self._gen_name = gen_name 2118 self.preface: List[str] = [] 2119 2120 def visit(self, node: ast.AST) -> Any: 2121 if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES): 2122 expr = _ast_unparse(node) 2123 2124 # Replacement only occurs if a given expression is used more 2125 # than once. 2126 if self._config.expr_count[expr] > PyExprCSEPass.USE_THRESHOLD: 2127 if expr not in self._config.expr_to_name: 2128 # Parent 'visit' is called so that we CSE the inner expressions first. 2129 # 2130 # The resulting expression is used as right-hand-side of the variable 2131 # assignment. i.e. we are CSE-ing the children before the parents. 2132 # 2133 # Indexing still uses the old 'node', since that's what was counted 2134 # by the 'NodeVisitor'. 
2135 node_ = super().visit(node) 2136 expr_ = _ast_unparse(node_) 2137 var_name = self._gen_name() 2138 self.preface.append(f"{var_name} = {expr_}") 2139 self._config.expr_to_name[expr] = var_name 2140 else: 2141 var_name = self._config.expr_to_name[expr] 2142 return ast.Name(var_name, ast.Load()) 2143 2144 return super().visit(node) 2145 2146 def __init__(self) -> None: 2147 self._counter = 0 2148 self._config = self.Config( 2149 expr_count=collections.defaultdict(lambda: 0), expr_to_name={} 2150 ) 2151 2152 def _new_var(self, prefix: str = "_var") -> str: 2153 name = f"{prefix}{self._counter}" 2154 self._counter += 1 2155 return name 2156 2157 def count(self, exprs: List[str]) -> None: 2158 counter = self.ExprCounter(self._config) 2159 for e in exprs: 2160 try: 2161 counter.visit(ast.parse(e)) 2162 except SyntaxError as ex: 2163 log.exception("Failed to visit expr at line %s.\n%s", ex.lineno, e) 2164 raise 2165 2166 def replace(self, expr: str) -> Tuple[List[str], str]: 2167 replacer = self.Replacer(self._config, self._new_var) 2168 new_node = replacer.visit(ast.parse(expr)) 2169 return replacer.preface, _ast_unparse(new_node) 2170 2171 2172def must_add_nn_module_guards(guard): 2173 # For config.guard_nn_modules=False, we can skip all the guards that 2174 # originate from inside of nn module except for a few categories. 2175 return ( 2176 # Guard for defaults 2177 isinstance(guard.originating_source, DefaultsSource) 2178 # Guard using dict tags if the config flag is set 2179 or ( 2180 config.guard_nn_modules_using_dict_tags 2181 and guard.create_fn is GuardBuilder.NN_MODULE 2182 ) 2183 ) 2184 2185 2186class DeletedGuardFn: 2187 pass 2188 2189 2190# NB: Naively, you'd expect this to only be a function that produces 2191# the callable that constitutes the guard. However, there is some 2192# delicate handling for invalidating this check function when the 2193# locals/globals get invalidated, so there's some extra state 2194# we have to hold in this manager class. 
class CheckFunctionManager:
    """Build the guard check function for a compiled frame and manage its invalidation.

    Construction works as follows:
      * Every guard in `output_graph.guards` is run through a GuardBuilder.
        Specialized nn-module guards are skipped when nn-module guarding is
        off, except defaults and, depending on config, hook guards.
      * The result is stored in `self.check_fn`. This is the C++ GuardManager
        when `config.enable_cpp_guard_manager` is set, and a generated Python
        closure otherwise.
      * Objects registered through `id_ref` get a `weakref.finalize` callback
        that calls `invalidate`.
    """

    def __init__(
        self,
        output_graph=None,
        guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
    ):
        # NOTE(review): output_graph defaults to None, but the GuardBuilder
        # construction below dereferences it unconditionally.
        guards = output_graph.guards if output_graph else None
        self._weakrefs: Dict[int, ReferenceType[object]] = {}
        self.guard_manager = None
        if config.enable_cpp_guard_manager:
            self.guard_manager = GuardManager()
        self.output_graph = output_graph
        w_builder = None

        self.torch_function_mode_stack = (
            output_graph.torch_function_mode_stack if output_graph else None
        )

        def source_ref(source):
            guard_source = source.guard_source()
            if guard_source is GuardSource.CONSTANT:
                # No need to track constants
                return source.name()
            # The builder is only reachable through a weakref (see below) so
            # that this closure does not keep it alive.
            assert w_builder
            r_builder = w_builder()
            assert r_builder is not None
            return r_builder.arg_ref(source.name())

        builder = GuardBuilder(
            self.id_ref,
            source_ref,
            self.lookup_weakrefs,
            output_graph.local_scope,
            output_graph.global_scope,
            self.guard_manager,
            self,
        )

        # Break retain cycle. See test_release_scope_memory
        def cleanup_builder(weak_b):
            b = weak_b()
            if b:
                b.scope = None

        # Break retain cycle. See test_release_input_memory
        w_builder = weakref.ref(builder, cleanup_builder)

        # Query the killswitch once and reuse the result for both the
        # decision and the warning.
        guard_nn_modules_knob = justknobs_check("pytorch/compiler:guard_nn_modules")
        guard_on_nn_modules = config.guard_nn_modules and guard_nn_modules_knob

        if not guard_nn_modules_knob:
            log.warning("guard_nn_modules is turned off using justknobs killswitch")

        for guard in sorted(guards or [], key=Guard.sort_key):
            if (
                not guard_on_nn_modules
                and guard.is_specialized_nn_module()
                # Default func args must be guarded on.
                # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
                and "__defaults__" not in guard.name
                and "__kwdefaults__" not in guard.name
                and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
            ):
                continue

            guard.create(builder)

        self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)

        # Keep track of weak references of objects with ID_MATCH guard. This
        # info is stored alongside optimized_code and check_fn and is used to
        # limit the number of cache entries with same ID_MATCH'd object.
        # TODO(anijain2305) - Currently this information is stored as an attr on
        # the check_fn itself to avoid changing CacheEntry datastructure in
        # eval_frame.c. In future, we should probably replace check_fn with a
        # queryable data structure such that this information is already present
        # in some form.
        self.check_fn.id_matched_objs = builder.id_matched_objs

        if config.enable_cpp_guard_manager:
            # TODO: don't do the string rep, do something more structured here
            torch._logging.trace_structured(
                "dynamo_cpp_guards_str", payload_fn=lambda: str(self.guard_manager)
            )
            guards_log.debug("%s", self.guard_manager)
            assert self.guard_manager  # to make mypy happy
            self.guard_manager.id_matched_objs = builder.id_matched_objs
            self.check_fn = self.guard_manager

            # Check that the guard returns True. False means that we will always
            # recompile.
            # TODO(anijain2305, ydwu4) - Skipping export because of following test
            # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs
            if not output_graph.export:
                if not self.guard_manager.check(output_graph.local_scope):
                    reasons = get_guard_fail_reason_helper(
                        self.guard_manager,  # type: ignore[arg-type]
                        output_graph.local_scope,
                        CompileContext.current_compile_id(),
                    )
                    raise AssertionError(f"Guard check failed: {reasons}")

        # NB - We have to be very careful of cleaning up here. Because of the
        # invalidate function, we can create a weakref finalizer that keeps
        # `self` alive for very long. Sometimes by mistake, we can run
        # invalidate for a type/object (check id_ref method) that Python can
        # leak by design, preventing us from calling the finalizer. In that
        # case, the `self` will be alive even though the cache entry will be
        # deleted (check invalidate method), which can cause a memory leak,
        # e.g., not setting output_graph = None can keep hold of nn_modules.
        self._weakrefs.clear()
        self.output_graph = None

    def compile_check_fn(self, builder, guards_out, guard_fail_fn):
        """Assemble the final guard function from everything `builder` collected.

        With the C++ guard manager, the global-state, torch-function-mode-stack,
        tensor-aliasing and duplicate-input guards are installed on
        `self.guard_manager`, which is returned. Otherwise the collected code
        parts are compiled into a Python closure via `build_guard_function`.

        In both cases debugging/bookkeeping attributes are attached to the
        returned object: `closure_vars`, `args`, `verbose_code_parts`,
        `global_scope`, `guard_fail_fn`, `cache_entry`, `extra_state` and
        `no_tensor_aliasing_sources`.

        Args:
            builder: GuardBuilder on which every guard's `create` has already run.
            guards_out: not used by this method.
            guard_fail_fn: optional callback stored as `guard_fn.guard_fail_fn`.
        """
        # see parallel handling of ".0" / "___implicit0" in _eval_frame.c
        # Build a new list rather than extending builder.argnames in place.
        largs = [*builder.argnames, "**___kwargs_ignored"]

        guards_log.debug("GUARDS:")

        code_parts = []
        verbose_code_parts = []
        structured_guard_fns: list[Callable[[], dict[str, Any]]] = []

        torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
            self.torch_function_mode_stack
        )

        if config.enable_cpp_guard_manager:
            from .variables.torch_function import IGNORED_MODES

            # Insert the global_state guard
            assert self.guard_manager  # to make mypy happy
            self.guard_manager.root.add_global_state_guard(["___check_global_state()"])

            self.guard_manager.root.add_torch_function_mode_stack_guard(
                self.torch_function_mode_stack,
                list(IGNORED_MODES),
                ["___check_torch_function_mode_stack()"],
            )
            # Clear references to torch_function modes held in the list
            self.torch_function_mode_stack = None
        else:
            # Don't report this guard, it's always the same, useless!
            global_guard = "___check_global_state()"
            code_parts.append(global_guard)
            verbose_code_parts.append(global_guard)

            tf_mode_stack_guard = "___check_torch_function_mode_stack()"
            code_parts.append(tf_mode_stack_guard)
            verbose_code_parts.append(tf_mode_stack_guard)

        def add_code_part(code_part, guard, log_only=False):
            # Log `code_part`. Unless `log_only`, also add it to the Python
            # guard body. `guard` may be None (see the DuplicateInputs loop
            # below), so every access to its attributes is None-checked.
            verbose_code_part = get_verbose_code_part(code_part, guard)
            guards_log.debug("%s", verbose_code_part)

            structured_guard_fns.append(
                lambda: {
                    "code": code_part,
                    "stack": structured.from_traceback(guard.stack.summary())
                    if guard is not None and guard.stack
                    else None,
                    "user_stack": structured.from_traceback(guard.user_stack)
                    if guard is not None and guard.user_stack
                    else None,
                }
            )

            if verbose_guards_log.isEnabledFor(logging.DEBUG):
                maybe_stack = ""
                maybe_user_stack = ""
                if guard is not None:
                    if guard.stack:
                        maybe_stack = f"\nStack:\n{''.join(guard.stack.format())}"
                    if guard.user_stack:
                        maybe_user_stack = (
                            f"\nUser stack:\n{''.join(guard.user_stack.format())}"
                        )
                verbose_guards_log.debug(
                    "Guard: %s%s%s",
                    code_part,
                    maybe_stack,
                    maybe_user_stack,
                )

            if not log_only:
                code_parts.append(code_part)
                verbose_code_parts.append(verbose_code_part)

        seen = set()
        for gcl in builder.code:
            for code in gcl.code_list:
                if code not in seen:
                    # If Cpp guard manager is enabled, we don't need to add to
                    # code_parts.
                    add_code_part(code, gcl.guard, config.enable_cpp_guard_manager)
                    seen.add(code)

        tensor_check_names = builder.tensor_check_names
        check_tensors_fn = None
        check_tensors_verbose_fn = None
        if tensor_check_names and not config.enable_cpp_guard_manager:
            tensor_check_guards = builder.tensor_check_guards
            assert (
                not self.output_graph.export
            ), "Illegal to set tensor_check_names in export."
            tensor_check_examples = builder.tensor_check_examples

            dynamic_dims_sizes = []
            dynamic_dims_strides = []
            for g in tensor_check_guards:
                metadata = self.output_graph.input_source_to_sizes_strides[
                    g.originating_source
                ]
                dynamic_dims_sizes.append(convert_to_concrete_values(metadata["size"]))
                dynamic_dims_strides.append(
                    convert_to_concrete_values(metadata["stride"])
                )

            tensor_guards = TensorGuards(
                *tensor_check_examples,
                dynamic_dims_sizes=dynamic_dims_sizes,
                dynamic_dims_strides=dynamic_dims_strides,
            )
            check_tensors_fn = tensor_guards.check
            check_tensors_verbose_fn = tensor_guards.check_verbose
            tensor_check_args = ", ".join(
                tensor_check_names + ["tensor_check_names=tensor_check_names"]
            )
            # Do this manually, to un-stagger the guards in log message
            code_parts.append(f"___check_tensors({tensor_check_args})")
            verbose_code_parts.append(f"___check_tensors({tensor_check_args})")

            for i, name in enumerate(tensor_check_names):
                # This is a copy of what guards.cpp checks against
                # Keep this in sync with TensorCheck constructor
                t = tensor_check_examples[i]
                sizes = dynamic_dims_sizes[i]
                strides = dynamic_dims_strides[i]
                code_part = get_tensor_guard_code_part(t, name, sizes, strides)
                add_code_part(code_part, tensor_check_guards[i], log_only=True)

        if len(tensor_check_names) > 1 and config.enable_cpp_guard_manager:
            # Install tensor aliasing guard. TENSOR_MATCH guards are already
            # installed for cpp guard manager.
            install_no_tensor_aliasing_guard(
                builder.tensor_check_guard_managers,
                tensor_check_names,
                ["check_no_aliasing(" + ", ".join(tensor_check_names) + ")"],
            )

        aotautograd_guards: List[GuardEnvExpr] = (
            self.output_graph.tracing_context.guards_context.aotautograd_guards
            if self.output_graph
            else []
        )

        # TODO(anijain2305) - There is a duplicate logic in Dynamo to find
        # aliased input tensors. So most probably we don't need this here.
        # Revisit.
        for guard in aotautograd_guards:
            if isinstance(guard, DuplicateInputs):
                source_a = guard.input_source_a
                source_b = guard.input_source_b
                code_part = f"{source_a.name()} is {source_b.name()}"
                if config.enable_cpp_guard_manager:
                    install_object_aliasing_guard(
                        builder.get_guard_manager_from_source(source_a),
                        builder.get_guard_manager_from_source(source_b),
                        [code_part],
                    )
                add_code_part(code_part, None, config.enable_cpp_guard_manager)
            else:
                raise RuntimeError(f"Unknown GuardEnvExpr: {guard}")

        # TODO: the "guard" here is actually just the top level SHAPE_ENV
        # which is useless.  Get ShapeEnv to pass in more provenance.
        for gcl in builder.shape_env_code:
            for code in gcl.code_list:
                # Shape env guards are already added for CPP guard manager in
                # SHAPE_ENV implementation.
                add_code_part(code, gcl.guard, config.enable_cpp_guard_manager)

        # OK, all done generating guards
        if structured_guard_fns:
            torch._logging.trace_structured(
                "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
            )

        global_state = convert_frame.initial_global_state
        if global_state is None:
            # we should only hit this case in NopTests()
            global_state = convert_frame.GlobalStateGuard()
        closure_vars = {
            "___check_tensors": check_tensors_fn,
            "___check_tensors_verbose": check_tensors_verbose_fn,
            "___check_global_state": global_state.check,
            "___check_torch_function_mode_stack": torch_function_mode_stack_check_fn,
            "tensor_check_names": tensor_check_names,
            **SYMPY_INTERP,
            **CLOSURE_VARS,
        }

        globals_for_guard_fn = {"G": builder.scope["G"]}
        if config.enable_cpp_guard_manager:
            # Guard manager construction is complete
            assert self.guard_manager  # to make mypy happy
            # TODO (anijain2305) - When enable_cpp_guard_manager is ON by
            # default, change the guard_fn name to be guard_manager everywhere
            # to avoid confusion.
            guard_fn = self.guard_manager
            # Ensure we did not miss to insert a guard in cpp guard manager.
            assert len(code_parts) == 0
        else:
            unique_code_parts = list(unique(code_parts))
            make_guard_fn_args = ", ".join(closure_vars.keys())
            guard_body, pycode = build_guard_function(
                unique_code_parts, make_guard_fn_args
            )

            if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1":
                print("GUARDS\n", guard_body)

            out: Dict[str, Any] = {}

            # We don't put builder.scope as the globals in exec call because
            # guard_fn.__globals__ becomes equal to builder.scope. This causes
            # guard_fn to hold a reference to f_locals sitting in builder.scope["L"]
            try:
                exec(pycode, globals_for_guard_fn, out)
            except SyntaxError as ex:
                log.exception("Failed to exec guard at line %s.\n%s", ex.lineno, pycode)
                raise
            guard_fn = out["___make_guard_fn"](*closure_vars.values())

        guard_fn.closure_vars = closure_vars
        # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
        guard_fn.args = largs
        if config.enable_cpp_guard_manager:
            guard_fn.populate_code_parts_for_debugging()
        else:
            guard_fn.code_parts = code_parts
        guard_fn.verbose_code_parts = verbose_code_parts
        # Grab only G, but preserve "G" because guards access it as "G"
        guard_fn.global_scope = globals_for_guard_fn
        guard_fn.guard_fail_fn = guard_fail_fn
        # will be populated by a non-owning reference to CacheEntry/ExtraState
        # when the CacheEntry is constructed
        guard_fn.cache_entry = None
        guard_fn.extra_state = None
        guard_fn.no_tensor_aliasing_sources = tensor_check_names
        return guard_fn

    def invalidate(self):
        """Drop this manager's cache entry and replace `check_fn` with DeletedGuardFn.

        Registered via `weakref.finalize` in `id_ref`, so it may run when an
        object registered there is garbage collected. It does nothing unless
        `check_fn` exists, is not already the sentinel, and has both
        `cache_entry` and `extra_state` set.
        """
        # Some tests reveal that CheckFunctionManager has no attribute
        # check_fn, but this case should not be of any concern.
        # This case doesn't seem easy to repro.
        if (
            hasattr(self, "check_fn")
            and self.check_fn is not DeletedGuardFn
            and (cache_entry := self.check_fn.cache_entry) is not None
            and (extra_state := self.check_fn.extra_state) is not None
        ):
            assert isinstance(cache_entry, CacheEntry)
            assert isinstance(extra_state, ExtraState)
            extra_state.invalidate(cache_entry)
            self.check_fn.cache_entry = None
            self.check_fn.extra_state = None
            self.check_fn = DeletedGuardFn

    def id_ref(self, obj):
        """add a weakref, return the id"""
        try:
            if id(obj) not in self._weakrefs:
                # We will clear the _weakrefs dict at the end of __init__
                # function, which will delete the callbacks as well. Therefore,
                # we are using a finalizer which is kept alive.
                self._weakrefs[id(obj)] = weakref.ref(obj)
                weakref.finalize(obj, self.invalidate)
        except TypeError:
            pass  # cannot weakref bool object
        return id(obj)

    def lookup_weakrefs(self, obj):
        """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects"""
        if id(obj) in self._weakrefs:
            return self._weakrefs[id(obj)]
        return None


def build_guard_function(code_parts, closure_args) -> Tuple[str, str]:
    """Generate Python source for the guard function.

    Args:
        code_parts: guard expressions. Each one becomes an
            `if not (...): return False` check.
        closure_args: comma-separated parameter list for `___make_guard_fn`.

    Returns:
        (guard_body, make_guard_fn_source). `make_guard_fn_source` defines
        `___make_guard_fn(<closure_args>)`, which returns an inner
        `guard(L)`. When HAS_UNPARSE_FUNCTIONS is true, repeated
        sub-expressions are hoisted into variables by PyExprCSEPass.
    """
    from torch._inductor.utils import IndentedBuffer

    if HAS_UNPARSE_FUNCTIONS:
        csepass = PyExprCSEPass()
        csepass.count(code_parts)

        def replace(expr: str) -> Tuple[List[str], str]:
            return csepass.replace(expr)

    else:

        def replace(expr: str) -> Tuple[List[str], str]:
            return [], expr

    # Generate the inner body of the guard function.
    # i.e. if-chain of the guard expressions.
    guard_body = IndentedBuffer()
    for expr in code_parts:
        preface, expr = replace(expr)
        guard_body.writelines(preface)
        guard_body.writeline(f"if not ({expr}):")
        with guard_body.indent():
            guard_body.writeline("return False")

    # Wrap the inner body into the actual guard function.
    guard = IndentedBuffer()
    guard.writeline("def guard(L):")
    with guard.indent():
        guard.splice(guard_body)
        guard.writeline("return True")

    # Wrap the whole guard function into another function
    # with the closure variables.
    make_guard_fn = IndentedBuffer()
    make_guard_fn.writeline(f"def ___make_guard_fn({closure_args}):")
    with make_guard_fn.indent():
        make_guard_fn.splice(guard)
        make_guard_fn.writeline("return guard")

    return guard_body.getvalue(), make_guard_fn.getvalue()


def is_recompiles_enabled():
    """Return True if the "recompiles" logging artifact is enabled."""
    return torch._logging._internal.log_state.is_artifact_enabled("recompiles")


def is_recompiles_verbose_enabled():
    """Return True if the "recompiles_verbose" logging artifact is enabled."""
    return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose")


# this will only be used if cpp guards are disabled
def make_torch_function_mode_stack_guard(intial_stack):
    """Return a zero-arg check comparing the current torch-function mode stack to `intial_stack`.

    The check returns False if the stack length differs, or if any mode's
    type differs from the recorded type. Recorded types listed in
    IGNORED_MODES are skipped.
    """
    # Named `stack_types` rather than `types` to avoid shadowing the `types`
    # module imported by this file.
    stack_types = [type(x) for x in intial_stack]
    from .variables.torch_function import IGNORED_MODES

    def check_torch_function_mode_stack():
        cur_stack = get_torch_function_mode_stack()
        if len(cur_stack) != len(stack_types):
            return False

        for ty, mode in zip(stack_types, cur_stack):
            if ty in IGNORED_MODES:
                continue
            if ty != type(mode):
                return False

        return True

    return check_torch_function_mode_stack


def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
    """Explain a failed no-tensor-aliasing guard.

    Each source in `no_tensor_aliasing_sources` is evaluated in `scope`, and
    sources that resolve to the same object id are reported together.
    """
    duplicate_tensors = []
    global_scope = dict(guard_manager.global_scope)
    ids_to_source = collections.defaultdict(list)
    for tensor_source in guard_manager.no_tensor_aliasing_sources:  # type: ignore[attr-defined]
        global_scope["__compile_source__"] = tensor_source
        tensor_id = id(eval(tensor_source, global_scope, scope))
        ids_to_source[tensor_id].append(tensor_source)

    for sources in ids_to_source.values():
        if len(sources) > 1:
            duplicate_tensors.append(f"{sources}")

    reason = ", ".join(duplicate_tensors)
    return [f"Duplicate tensors found: {reason}"]


def get_guard_fail_reason_helper(
    guard_fn: GuardFn,
    f_locals: Dict[str, object],
    compile_id: CompileId,
) -> str:
    """
    Return the reason why `guard_fn` failed, prefixed with `compile_id`.

    Unless "recompiles_verbose" logging is enabled, only the first failed
    check of guard_fn is reported. With it enabled, every failing check is
    reported, and checks that raise while being re-evaluated are skipped.
    """
    scope = {"L": f_locals, "G": guard_fn.global_scope["G"]}
    scope.update(guard_fn.closure_vars)
    reasons: List[str] = []

    no_tensor_aliasing_check_failed = False

    verbose_code_parts: List[str] = []
    if config.enable_cpp_guard_manager:
        guard_manager = guard_fn
        guard_debug_info = guard_manager.check_verbose(f_locals)  # type: ignore[attr-defined]
        # For test_export_with_map_cond, the check_verbose fail even without the
        # C++ guard manager. We need to fix the issue to remove the comment.
        # assert not guard_debug_info.result
        if not guard_debug_info.result:
            verbose_code_parts = guard_debug_info.verbose_code_parts
            # verbose_code_parts is either the actual reason (e.g. in case of
            # TENSOR_MATCH) or it could be a list of verbose_code_part that we
            # passed to the leaf guard at construction time. If its a list, we
            # walk through this list and find the guard that failed. This is
            # very important for symbolic shape guards which are currently
            # installed as a lambda guard and can encompass a long list of code_parts.

            if len(verbose_code_parts) == 1:
                if "Duplicate tensor found" in verbose_code_parts[0]:
                    no_tensor_aliasing_check_failed = True
                else:
                    reasons = verbose_code_parts
                    verbose_code_parts = []
    else:
        verbose_code_parts = guard_fn.verbose_code_parts
        # This is not needed for CPP guard because the verbose check is already
        # run in C++.
        scope["___check_tensors"] = scope["___check_tensors_verbose"]

    if no_tensor_aliasing_check_failed:
        reasons = recompilation_reason_for_no_tensor_aliasing_guard(guard_fn, scope)
    else:
        for part in verbose_code_parts:
            global_scope = dict(guard_fn.global_scope)
            global_scope["__compile_source__"] = part
            with report_compile_source_on_error():
                try:
                    fail_reason = eval(part, global_scope, scope)
                except Exception:
                    if is_recompiles_verbose_enabled():
                        continue
                    else:
                        raise
            # Only ___check_tensors knows how to return a fancy fail reason;
            # for everything else we just report the code that failed

            if isinstance(fail_reason, bool) and not fail_reason:
                fail_reason = part
            if isinstance(fail_reason, str):
                reasons.append(fail_reason)
                if not is_recompiles_verbose_enabled():
                    break

    reason_str = f"{compile_id}: " + "; ".join(reasons)
    return reason_str


def get_guard_fail_reason(
    guard_fn: GuardFn,
    code: types.CodeType,
    f_locals: Dict[str, object],
    compile_id: CompileId,
) -> str:
    """Compute the failure reason for `guard_fn` and record it.

    The reason is appended to `guard_failures` and passed to
    `guard_fn.guard_fail_fn` when one is set. Exceptions raised by that
    callback are logged, not propagated.
    """
    reason_str = get_guard_fail_reason_helper(guard_fn, f_locals, compile_id)
    guard_failures[orig_code_map[code]].append(reason_str)

    try:
        if guard_fn.guard_fail_fn is not None:
            guard_fn.guard_fail_fn(
                GuardFail(reason_str or "unknown reason", orig_code_map[code])
            )
    except Exception:
        log.exception(
            "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval",
        )

    return reason_str


def get_and_maybe_log_recompilation_reason(
    cache_entry, frame: types.FrameType
) -> List[str]:
    """
    Return the list of guard failure reasons using cache_entry.
    Logs the recompilation reason if `recompiles` logging is enabled.
    Raises a RecompileError if `config.error_on_recompile` is enabled.
    """
    reasons = []
    # Walk the linked list of cache entries and collect each failure reason.
    while cache_entry is not None:
        reason = get_guard_fail_reason(
            cache_entry.check_fn,
            cache_entry.code,
            frame.f_locals,
            cache_entry.compile_id,
        )
        if reason:
            reasons.append(reason)
        cache_entry = cache_entry.next

    code = frame.f_code

    # at least one of "recompiles" or "recompiles_verbose" is enabled
    do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled()

    if do_recompiles_log or config.error_on_recompile:
        if is_recompiles_verbose_enabled():
            failures = "\n\n".join(
                f"guard {i} failures:\n" + textwrap.indent(reason, "- ")
                for i, reason in enumerate(reasons)
            )
        else:
            failures = textwrap.indent("\n".join(reasons), "- ")
        guard_failure_details = (
            f"triggered by the following guard failure(s):\n{failures}"
        )
        message = (
            f"Recompiling function {code.co_name} in {code.co_filename}:{code.co_firstlineno}\n"
            f"{textwrap.indent(guard_failure_details, ' ')}"
        )
        if do_recompiles_log:
            if is_recompiles_verbose_enabled():
                recompiles_verbose_log.debug(message)
            else:
                recompiles_log.debug(message)
        if config.error_on_recompile:
            raise exc.RecompileError(message)

    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "recompile_reasons",
            "encoding": "json",
        },
        payload_fn=lambda: reasons,
    )

    return reasons


def guard_error_hook(
    guard_fn: GuardFn,
    code: types.CodeType,
    f_locals: Dict[str, object],
    index: int,
    last: bool,
):
    """Print diagnostics when evaluating guards raised an error.

    Prints the guard code parts and re-evaluates each one individually,
    printing any part whose evaluation raises.
    """
    print(
        f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}"
    )
    print("lambda " + ", ".join(guard_fn.args) + ":")
    print(" ", " and\n ".join(guard_fn.code_parts))

    if config.enable_cpp_guard_manager:
        print(guard_fn)

    local_scope = {"L": f_locals, **guard_fn.closure_vars}
    for guard in guard_fn.code_parts:
        try:
            eval(guard, guard_fn.global_scope, local_scope)
        except:  # noqa: B001,E722
            print(f"Malformed guard:\n{guard}")


set_guard_error_hook(guard_error_hook)


def unique(seq):
    """Yield the items of `seq` in order, skipping ones already seen (items must be hashable)."""
    seen = set()
    for x in seq:
        if x not in seen:
            yield x
            seen.add(x)


def make_dupe_guard(obj_source, dupe_source):
    """Return a DUPLICATE_INPUT guard factory for two aliasing sources, or None.

    Raises exc.UnsafeScriptObjectError if either source comes from a
    flattened script object.
    """
    # Note - we may end up in a situation where we invoke something like
    # def fn(x, y)
    # with fn(x, x)
    # Prior to the addition of tracking to all relevant objects, we would handle this just fine by
    # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. However,
    # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here -
    # In the fn(x, x) example call above look like a graph with a single input.
    # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard.

    # Note - we may not have a source, that is fine, it just means we had an object that is safe to have
    # leave unsourced - like a local list created and discharged entirely within a local scope.
    if dupe_source and dupe_source != obj_source:
        ser_source_is_local = is_from_local_source(dupe_source)
        source_is_local = is_from_local_source(obj_source)
        if is_from_flatten_script_object_source(
            dupe_source
        ) or is_from_flatten_script_object_source(obj_source):
            raise exc.UnsafeScriptObjectError(
                f"{obj_source.name()} is alising {dupe_source.name()}. This is not supported."
                f" Please do a clone for corresponding input."
            )

        # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
        # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
        # so maybe we should do this refactor before we land this...
        # TODO(voz): Combine local and global guard builders.
        if ser_source_is_local == source_is_local:
            # Note - this is a little aggressive - these being duplicate input does not always matter.
            # However, this should always be a sound guard to add here.
            return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
    return None


def install_guard(*guards, skip=0):
    """
    Add dynamo guards to the current tracing context.

    Args:
        guards: guard(s) to add
        skip: number of stack frames to ignore for debug stack trace
    """
    from torch._guards import TracingContext

    # Debug stacks are only collected when some guard logging is enabled.
    collect_debug_stack = guards_log.isEnabledFor(
        logging.DEBUG
    ) or verbose_guards_log.isEnabledFor(logging.DEBUG)
    add = TracingContext.get().guards_context.dynamo_guards.add
    for guard in guards:
        assert isinstance(guard, Guard)
        add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1)