xref: /aosp_15_r20/external/pytorch/torch/_dynamo/guards.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import ast
5import builtins
6import collections
7import dataclasses
8import enum
9import functools
10import importlib
11import inspect
12import itertools
13import logging
14import math
15import os
16import re
17import sys
18import textwrap
19import types
20import weakref
21from contextlib import contextmanager
22from copy import deepcopy
23from inspect import currentframe, getframeinfo
24from typing import (
25    Any,
26    Callable,
27    Dict,
28    List,
29    Optional,
30    Set,
31    Tuple,
32    Type,
33    TYPE_CHECKING,
34    Union,
35)
36from weakref import ReferenceType
37
38import torch
39import torch.utils._device
40from torch._C._dynamo.guards import (
41    check_obj_id,
42    check_type_id,
43    dict_version,
44    DictGuardManager,
45    install_no_tensor_aliasing_guard,
46    install_object_aliasing_guard,
47    RootGuardManager,
48    TensorGuards,
49)
50from torch._dynamo.source import (
51    is_from_flatten_script_object_source,
52    is_from_local_source,
53    is_from_optimizer_source,
54    TensorProperty,
55    TensorPropertySource,
56)
57from torch._guards import (
58    CompileContext,
59    CompileId,
60    DuplicateInputs,
61    Guard,
62    GuardBuilderBase,
63    GuardEnvExpr,
64    GuardSource,
65    Source,
66)
67from torch._logging import structured
68from torch._utils_internal import justknobs_check
69from torch.fx.experimental.symbolic_shapes import (
70    EqualityConstraint,
71    is_symbolic,
72    SYMPY_INTERP,
73)
74from torch.utils._traceback import format_frame, report_compile_source_on_error
75from torch.utils.weak import TensorWeakRef
76
77from . import config, convert_frame, exc, mutation_guard
78from .eval_frame import set_guard_error_hook
79from .source import (
80    AttrProxySource,
81    AttrSource,
82    ChainedSource,
83    ConstDictKeySource,
84    DefaultsSource,
85    FlattenScriptObjectSource,
86    FSDPNNModuleSource,
87    GetItemSource,
88    GlobalSource,
89    GlobalStateSource,
90    GlobalWeakRefSource,
91    GradSource,
92    LocalSource,
93    NNModuleSource,
94    NumpyTensorSource,
95    ODictGetItemSource,
96    OptimizerSource,
97    ScriptObjectQualifiedNameSource,
98    ShapeEnvSource,
99    SubclassAttrListSource,
100    TupleIteratorGetItemSource,
101    TypeSource,
102    UnspecializedBuiltinNNModuleSource,
103    UnspecializedNNModuleSource,
104    UnspecializedParamBufferSource,
105    WeakRefCallSource,
106)
107from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn  # noqa: F401
108from .utils import (
109    common_constant_types,
110    dict_keys_repr,
111    get_custom_getattr,
112    get_torch_function_mode_stack,
113    guard_failures,
114    istype,
115    key_is_id,
116    key_to_id,
117    orig_code_map,
118    tensor_always_has_static_shape,
119    tuple_iterator_getitem,
120    tuple_iterator_len,
121    unpatched_nn_module_getattr,
122    verify_guard_fn_signature,
123)
124
125
126try:
127    import numpy as np
128except ModuleNotFoundError:
129    np = None  # type: ignore[assignment]
130
131
132if TYPE_CHECKING:
133    from sympy import Symbol
134
135
# Standard module-level logger for this file.
log = logging.getLogger(__name__)
# Artifact loggers obtained from torch._logging, one per named artifact:
# "guards", "recompiles", "recompiles_verbose" and "verbose_guards".
guards_log = torch._logging.getArtifactLogger(__name__, "guards")
recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles")
recompiles_verbose_log = torch._logging.getArtifactLogger(
    __name__, "recompiles_verbose"
)
verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")
143
144
class GuardManager:
    """
    Python-side holder for the C++ RootGuardManager.

    The Dynamo cache entry keeps an instance of this class so that it can reach
    the RootGuardManager through the ``root`` attribute and call
    check_nopybind directly from C++. The other attributes are bookkeeping
    slots that are initialised to ``None`` / empty lists here.
    """

    def __init__(self):
        self.root = RootGuardManager()

        self.closure_vars = None
        self.args = None
        self.code_parts = []
        self.verbose_code_parts = None
        self.global_scope = None
        self.guard_fail_fn = None
        self.cache_entry = None
        self.extra_state = None
        self.id_matched_objs = None
        self.no_tensor_aliasing_sources = []

        # When True, the next NO_TENSOR_ALIASING guard encountered while
        # rendering the tree is printed in full; later ones are abbreviated.
        self.print_no_tensor_aliasing_guard = True

    @contextmanager
    def _preserve_print_no_tensor_aliasing_flag(self):
        """Set the print flag to True on entry and again on exit."""
        self.print_no_tensor_aliasing_guard = True
        try:
            yield
        finally:
            self.print_no_tensor_aliasing_guard = True

    def get_guard_lines(self, guard):
        """Return the guard's verbose code parts, each prefixed with its class name."""
        prefix = guard.__class__.__name__ + ": "
        return [prefix + part for part in guard.verbose_code_parts()]

    def get_manager_line(self, guard_manager, accessor_str=None):
        """Return a one-line description of a manager: class name, source, accessor."""
        source = guard_manager.get_source()
        pieces = [guard_manager.__class__.__name__ + ": source=" + source]
        if accessor_str:
            pieces.append(accessor_str)
        return ", ".join(pieces)

    def construct_dict_manager_string(self, mgr, body):
        """Write the key/value manager pairs of a dict manager into ``body``."""
        for idx, managers in sorted(mgr.get_key_value_managers().items()):
            key_mgr, val_mgr = managers
            body.writeline(f"KeyValueManager pair at index={idx}")
            with body.indent():
                for label, sub_mgr in (
                    ("KeyManager", key_mgr),
                    ("ValueManager", val_mgr),
                ):
                    if sub_mgr:
                        body.writeline(f"{label}: {self.get_manager_line(sub_mgr)}")
                        self.construct_manager_string(sub_mgr, body)

    def construct_manager_string(self, mgr, body):
        """Recursively write ``mgr``'s leaf guards and child managers into ``body``."""
        with body.indent():
            for guard in mgr.get_leaf_guards():
                is_aliasing_guard = isinstance(
                    guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING  # type: ignore[attr-defined]
                )
                if is_aliasing_guard and not self.print_no_tensor_aliasing_guard:
                    # Already printed once in full; emit only the class name.
                    body.writelines([guard.__class__.__name__])
                    continue
                if is_aliasing_guard:
                    self.print_no_tensor_aliasing_guard = False
                body.writelines(self.get_guard_lines(guard))

            # This works for both DictGuardManager and SubclassedDictGuardManager
            if isinstance(mgr, DictGuardManager):
                self.construct_dict_manager_string(mgr, body)

            # General case of GuardManager/RootGuardManager
            accessors = mgr.get_accessors()
            children = mgr.get_child_managers()
            for accessor, child_mgr in zip(accessors, children):
                description = self.get_manager_line(
                    child_mgr, f"accessed_by={accessor.repr()}"
                )
                body.writeline(description)
                self.construct_manager_string(child_mgr, body)

    def __str__(self):
        """Render the whole guard tree as an indented, human-readable string."""
        from torch._inductor.utils import IndentedBuffer

        class IndentedBufferWithPrefix(IndentedBuffer):
            def prefix(self):
                return "| " * (self._indent * self.tabwidth)

            def writeline(self, line, skip_prefix=False):
                super().writeline(line if skip_prefix else "+- " + line)

        with self._preserve_print_no_tensor_aliasing_flag():
            body = IndentedBufferWithPrefix()
            body.tabwidth = 1
            for header in ("", "TREE_GUARD_MANAGER:"):
                body.writeline(header, skip_prefix=True)
            body.writeline("RootGuardManager")
            self.construct_manager_string(self.root, body)
            for guard in self.root.get_epilogue_lambda_guards():
                body.writelines(self.get_guard_lines(guard))
            return body.getvalue()

    def check(self, x):
        # Only needed for debugging purposes.
        return self.root.check(x)

    def check_verbose(self, x):
        # Only needed for debugging purposes.
        return self.root.check_verbose(x)

    def populate_code_parts_for_debugging(self):
        """
        Collect the code parts of every leaf guard into ``self.code_parts``.

        This should be called when the guard manager is fully populated. Only
        the first NO_TENSOR_ALIASING guard encountered contributes its parts.
        """
        tensor_aliasing_guard_seen = False

        def get_code_parts(leaf_guard):
            # Drop everything from the first "#" on, then trailing whitespace.
            return [
                verbose.split("#")[0].rstrip()
                for verbose in leaf_guard.verbose_code_parts()
            ]

        def visit(mgr):
            nonlocal tensor_aliasing_guard_seen
            for guard in mgr.get_leaf_guards():
                is_aliasing_guard = isinstance(
                    guard, torch._C._dynamo.guards.NO_TENSOR_ALIASING  # type: ignore[attr-defined]
                )
                if is_aliasing_guard and tensor_aliasing_guard_seen:
                    continue
                self.code_parts.extend(get_code_parts(guard))
                if is_aliasing_guard:
                    tensor_aliasing_guard_seen = True

            for child_mgr in mgr.get_child_managers():
                visit(child_mgr)

        visit(self.root)
289
290
def from_numpy(a):
    """
    Convert a NumPy scalar or array to a tensor; return anything else unchanged.

    Args:
        a: Any object.

    Returns:
        ``torch.as_tensor(a)`` when ``a`` is an ``np.generic`` or
        ``np.ndarray``; otherwise ``a`` itself.
    """
    # ``np`` is None when NumPy could not be imported (see the guarded import
    # near the top of this module). Nothing can be a NumPy object then, and
    # touching ``np.generic`` would raise AttributeError.
    if np is None:
        return a
    # If not numpy array, piggy back on e.g. tensor guards to check type
    return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a
294
295
296# For user stack printing
@functools.lru_cache(None)
def uninteresting_files():
    """
    Return the set of source-file paths to skip when picking a user stack frame.

    The result is computed once and cached; it currently holds only the file
    that defines ``torch._dynamo.external_utils``.
    """
    from torch._dynamo import external_utils

    skipped_modules = (external_utils,)
    return set(map(inspect.getfile, skipped_modules))
305
306
# Helper callables, constants and modules, keyed by the name each one is
# published under.
# NOTE(review): presumably these are the names visible to generated guard
# expressions - confirm against the code that consumes CLOSURE_VARS (not
# visible in this part of the file).
CLOSURE_VARS = {
    "___check_type_id": check_type_id,
    "___check_obj_id": check_obj_id,
    "___odict_getitem": collections.OrderedDict.__getitem__,
    "___key_to_id": key_to_id,
    "___dict_version": dict_version,
    "___dict_contains": lambda a, b: a in b,
    "___tuple_iterator_len": tuple_iterator_len,
    "___tuple_iterator_getitem": tuple_iterator_getitem,
    "__math_isnan": math.isnan,
    # None when NumPy is not installed.
    "__numpy_isnan": None if np is None else np.isnan,
    "inf": float("inf"),
    "__load_module": importlib.import_module,
    "utils_device": torch.utils._device,
    "device": torch.device,
    "___from_numpy": from_numpy,
    "___as_tensor": torch.as_tensor,
    "torch": torch,
    "inspect": inspect,
}
327
328if sys.version_info[:2] <= (3, 8):
329    # [Note: Python Version <= 3.8]
330    # This branch should be dropped when we drop support for Python 3.8.
331    # Reason: 'ast.unparse' function was introduced in Python 3.9.
332
333    try:
334        import astunparse  # type: ignore[import]
335
336        def _ast_unparse(node: ast.AST) -> str:
337            return astunparse.unparse(node).replace("\n", "")
338
339        HAS_UNPARSE_FUNCTIONS = True
340    except ImportError:
341        HAS_UNPARSE_FUNCTIONS = False
342else:
343    HAS_UNPARSE_FUNCTIONS = True
344
345    def _ast_unparse(node: ast.AST) -> str:
346        return ast.unparse(node).replace("\n", "")
347
348
def strip_function_call(name):
    """
    Reduce a source expression to the name of the object it starts from.

    "___odict_getitem(a, 1)" => "a"
    "a.layers[slice(2)][0]._xyz" ==> "a"
    "getattr(a.layers[slice(2)][0]._abc, '0')" ==> "a"
    "getattr(getattr(a.x[3], '0'), '3')" ==> "a"
    "a.layers[slice(None, -1, None)][0]._xyz" ==> "a"
    """
    # recursively find valid object name in function
    starts_like_identifier = re.compile("[A-Za-z_].*")
    token = ""
    for ch in name:
        if ch in " (":
            # A space or an opening parenthesis starts a fresh token.
            token = ""
            continue
        if ch not in "),[]":
            token += ch
            continue
        # A delimiter ends the token; recurse on the first usable one.
        if token and token != "None" and starts_like_identifier.match(token):
            return strip_function_call(token)

    return strip_getattr_getitem(name)


def strip_getattr_getitem(name):
    """
    Keep only the text before the first "." or "[".

    "a[1]" => "a"
    "a.foo" => "a"
    """
    head, *_ = re.split(r"[.\[]", name)
    return head
378
379
def get_verbose_code_part(code_part: str, guard: Guard) -> str:
    """
    Return ``code_part`` left-justified to 60 columns plus a location comment.

    The comment is built from the guard's user stack when it has one
    (scanning from the last entry backwards and taking the first frame whose
    file is not in ``uninteresting_files()``), otherwise from the last entry
    of ``guard.stack``. When neither applies nothing is appended.
    """
    suffix = ""
    if guard.user_stack:
        interesting = next(
            (
                frame
                for frame in reversed(guard.user_stack)
                if frame.filename not in uninteresting_files()
            ),
            None,
        )
        if interesting is not None:
            suffix = f"  # {format_frame(interesting, line=True)}"
    elif guard.stack:
        suffix = f"  # {format_frame(guard.stack.summary()[-1])}"

    return "{:<60}{}".format(code_part, suffix)
391
392
def get_verbose_code_parts(
    code_parts: Union[str, List[str]], guard: Guard
) -> List[str]:
    """
    Return the verbose form of one or more code parts for ``guard``.

    Args:
        code_parts: A single code string or a list of them. Anything that is
            not a ``list`` is treated as a single part.
        guard: The guard passed through to ``get_verbose_code_part``.

    Returns:
        One verbose string per code part, in input order.
    """
    # The annotation previously read ``Union[str | List[str]]``; the ``|``
    # form cannot be evaluated on Python 3.8/3.9, which this file still
    # supports, so the plain ``Union[str, List[str]]`` spelling is used.
    if not isinstance(code_parts, list):
        code_parts = [code_parts]
    return [get_verbose_code_part(code_part, guard) for code_part in code_parts]
399
400
401def convert_to_concrete_values(size_or_stride):
402    converted: List[Optional[int]] = []
403    for dim in size_or_stride:
404        if not is_symbolic(dim):
405            converted.append(dim)
406        else:
407            assert isinstance(dim, torch.SymInt)
408            converted.append(dim.node.maybe_as_int())
409    return converted
410
411
def get_tensor_guard_code_part(value, name, sizes, strides):
    """
    Build the ``check_tensor(...)`` string describing a tensor guard.

    The string contains the source ``name``, the Python type's qualified name,
    the dispatch key set (the tensor's keys plus the thread-local include set,
    minus the thread-local exclude set), the dtype, the device index,
    ``requires_grad`` and the given ``sizes`` / ``strides``.
    """
    included_keys = (
        torch._C._dispatch_keys(value) | torch._C._dispatch_tls_local_include_set()
    )
    dispatch_key = included_keys - torch._C._dispatch_tls_local_exclude_set()
    head = (
        f"check_tensor({name}, {type(value).__qualname__}, {dispatch_key}, "
        f"{value.dtype}, "
    )
    tail = (
        f"device={value.device.index}, requires_grad={value.requires_grad}, "
        f"size={sizes}, stride={strides})"
    )
    return head + tail
425
426
def get_key_index(dct, key):
    """Return the position of ``key`` in ``dct.keys()``; ValueError if absent."""
    ordered_keys = list(dct.keys())
    return ordered_keys.index(key)
429
430
def get_key_index_source(source, index):
    """Return the expression string that selects the ``index``-th key of ``source``."""
    return "list({}.keys())[{}]".format(source, index)
433
434
@dataclasses.dataclass(frozen=True)
class NNModuleAttrAccessorInfo:
    # Describes where an attribute name is found when it is accessed on an
    # nn module.

    # True when the attribute can be accessed via __dict__
    present_in_generic_dict: bool = False

    # Either the actual attribute name or one of _parameters/_buffers/_modules
    l1_key: Optional[str] = None

    # Actual parameter/buffer/submodule name (set only when l1_key is one of
    # _parameters/_buffers/_modules)
    l2_key: Optional[str] = None
448
449
def getitem_on_dict_manager(
    source, base_guard_manager, base_example_value, example_value, guard_manager_enum
):
    """
    Return the value manager for a dict item reached through ``source``.

    Args:
        source: Source whose ``base`` is the dict and whose ``index`` is either
            a ``ConstDictKeySource`` (carrying the key position in ``.index``)
            or the key itself.
        base_guard_manager: Manager of the dict; must provide
            ``get_key_manager`` and ``get_value_manager``.
        base_example_value: The example dict the key position is resolved in.
        example_value: Example value stored under the key.
        guard_manager_enum: Manager type requested for the value manager.

    Returns:
        Whatever ``base_guard_manager.get_value_manager`` returns for the
        resolved key position.
    """
    base_source_name = source.base.name()
    index_is_const_key = isinstance(source.index, ConstDictKeySource)
    if index_is_const_key:
        index = source.index.index
    else:
        assert isinstance(base_example_value, dict)
        index = get_key_index(base_example_value, source.index)

    key_source = get_key_index_source(base_source_name, index)
    key_example_value = list(base_example_value.keys())[index]
    # int/str keys are spelled literally in the value source string; other
    # keys are referred to by their position in the key list.
    if isinstance(key_example_value, (int, str)):
        value_source = f"{base_source_name}[{key_example_value!r}]"
    else:
        value_source = f"{base_source_name}[{key_source}]"
    if not index_is_const_key:
        # We have to insert a key manager guard here
        # TODO - source debug string is probably wrong here.
        base_guard_manager.get_key_manager(
            index=index,
            source=key_source,
            example_value=source.index,
            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
        ).add_equals_match_guard(
            source.index, [f"{key_source} == {key_example_value!r}"]
        )

    return base_guard_manager.get_value_manager(
        index=index,
        source=value_source,
        example_value=example_value,
        guard_manager_enum=guard_manager_enum,
    )
485
486
def match_on_id_for_tensor(guard):
    """
    Report whether the guard's source is a dict key and not a ``GradSource``.

    The source type is only inspected when ``is_dict_key()`` is truthy.
    """
    origin = guard.originating_source
    is_key = origin.is_dict_key()
    if not is_key:
        return is_key
    return not isinstance(origin, GradSource)
490
491
# The ready-to-eval generated code (possibly multiple parts) for a guard, plus
# the original guard object that created it, kept for provenance.
@dataclasses.dataclass
class GuardCodeList:
    # Python expression strings generated for the guard, one per part.
    code_list: List[str]
    # The Guard object the expressions were generated from.
    guard: Guard
497    guard: Guard
498
499
class GuardManagerType(enum.Enum):
    # Kind of guard manager requested for a source.

    # Default: used when the source needs no key-order guarding.
    GUARD_MANAGER = 1
    # Key order is guarded and the dict type does not override ``keys``.
    DICT_GUARD_MANAGER = 2
    # Key order is guarded and the dict type overrides ``keys``.
    DICT_SUBCLASS_GUARD_MANAGER = 3
504
505
506class GuardBuilder(GuardBuilderBase):
507    def __init__(
508        self,
509        id_ref: Callable[[Any], str],
510        source_ref: Callable[[Source], str],
511        lookup_weakrefs: Callable[[object], ReferenceType[object]],
512        local_scope: Dict[str, object],
513        global_scope: Dict[str, object],
514        guard_manager: Optional[GuardManager],
515        check_fn_manager: CheckFunctionManager,
516    ):
517        self.id_ref = id_ref
518        self.source_ref = source_ref
519        self.lookup_weakrefs = lookup_weakrefs
520        self.scope: Dict[str, Dict[str, object]] = {"L": local_scope, "G": global_scope}
521        self.scope["__builtins__"] = builtins.__dict__.copy()
522        for (
523            name,
524            package_module,
525        ) in torch.package.package_importer._package_imported_modules.items():
526            name = name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
527            # Write the package module into the scope so that we can import it
528            self.scope["__builtins__"][name] = package_module
529            # Write the demangled name to the scope so that we can use it
530            self.scope[name] = package_module
531        self.guard_manager = guard_manager
532
533        self.argnames: List[str] = []
534        # Code is python expression strings generated for each guard
535        self.code: List[GuardCodeList] = []
536        # shape_env_code is only used by builder and is used for
537        # shape env code.  This exists only because we need to make sure
538        # shape env guards get run after tensor match guards (since the
539        # tensor match guards make sure we actually have tensors)
540        self.shape_env_code: List[GuardCodeList] = []
541
542        # [Note - On Eager Tensor Guards]
543        # Most of the time, we generate Python code in a guard to directly
544        # check various properties.  However, tensors are a bit special;
545        # it is too slow to check their properties one-by-one in Python.
546        # Instead, there is a C++ function TensorGuards.check which takes
547        # all of the tensor arguments and checks them all against compile-time
548        # examples entirely in C++.  Thus, every time we process a
549        # TENSOR_MATCH guard, we just add another entry to
550        # tensor_check_names/tensor_check_examples, saying "for this local,
551        # check it against this example", and it all ends up getting
552        # swept up into a single call to ___check_tensors.  Invariant:
553        # len(tensor_check_names) == len(tensor_check_examples).
554        # TODO: something here
555        self.tensor_check_names: List[str] = []
556        self.tensor_check_examples: List[torch.Tensor] = []
557        self.tensor_check_guards: List[Guard] = []
558        self.tensor_check_guard_managers: List[GuardManager] = []
559
560        self.check_fn_manager: CheckFunctionManager = check_fn_manager
561
562        # Collect the ids of dicts which need key order guarding. source_name is
563        # not sufficient because for nn modules, we can have different sources
564        # to access the same object - self._module["param"] is same as
565        # self.param.
566        self.key_order_guarded_dict_ids = set()
567        for source_name in self.check_fn_manager.output_graph.guard_on_key_order:
568            self.key_order_guarded_dict_ids.add(id(self.get(source_name)))
569
570        # Keep track of weak references of objects with ID_MATCH guard. This
571        # info is stored alongside optimized_code and check_fn and is used to
572        # limit the number of cache entries with same ID_MATCH'd object.
573        self.id_matched_objs: Dict[str, ReferenceType[object]] = {}
574
575        # Save the guard managers to avoid repeatedly traversing sources.
576        self._cached_guard_managers: Dict[
577            str, torch._C._dynamo.guards.GuardManager
578        ] = {}
579
580        self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set()
581
582    def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
583        dict_mgr = self.get_guard_manager(guard)
584        if isinstance(dict_mgr, DictGuardManager):
585            raise NotImplementedError(
586                "Not expecting a DictGuardManager. Seems like Dynamo incorrectly "
587                f"added the dict to tx.output.guard_on_key_order for {guard.name}"
588            )
589
590        # Iterate over the dicts and install a dict_getitem_manager.
591        dict_source = guard.originating_source.name()
592        for key in example_value.keys():
593            value = example_value[key]
594            value_source = GetItemSource(guard.originating_source, index=key)
595            guard_manager_enum = self.get_guard_manager_type(
596                value_source, example_value
597            )
598            dict_mgr.dict_getitem_manager(
599                key=key,
600                source=f"{dict_source}[{key!r}]",
601                example_value=value,
602                guard_manager_enum=guard_manager_enum,
603            )
604
605    def guard_on_dict_keys_and_order(self, value, guard):
606        # Add key managers for the DictGuardManager. Then add either an
607        # ID_MATCH or EQUALS_MATCH guard on the key.
608        dict_mgr = self.get_guard_manager(guard)
609        if not isinstance(dict_mgr, DictGuardManager):
610            raise NotImplementedError(
611                "Expecting a DictGuardManager. Seems like Dynamo forgot "
612                f"to set the right guard manager enum for {guard.name}"
613            )
614        assert isinstance(dict_mgr, DictGuardManager)
615
616        for idx, key in enumerate(value.keys()):
617            key_source = get_key_index_source(guard.name, idx)
618            key_manager = dict_mgr.get_key_manager(
619                index=idx,
620                source=key_source,
621                example_value=key,
622                guard_manager_enum=GuardManagerType.GUARD_MANAGER,
623            )
624            if key_is_id(key):
625                # Install ID_MATCH guard
626                id_val = self.id_ref(key)
627                key_manager.add_id_match_guard(
628                    id_val,
629                    get_verbose_code_parts(
630                        f"__check_obj_id({key_source}, {id_val})", guard
631                    ),
632                )
633            else:
634                # Install EQUALS_MATCH guard
635                key_manager.add_equals_match_guard(
636                    key, get_verbose_code_parts(f"{key_source} == {key!r}", guard)
637                )
638
    def getattr_on_nn_module(
        self,
        source,
        base_guard_manager,
        base_example_value,
        example_value,
        base_source_name,
        source_name,
        guard_manager_enum,
    ):
        """
        This tries to avoid calling the expensive nn module custom getattr method by
        checking if the attribute is accessible via __dict__. For attributes that
        are not accessible via __dict__ (like descriptors), we fallback to
        PyObject_GetAttr.

        There are two cases that we optimize for
        1) attributes present directly in __dict__, e.g training.
        2) parameters/buffers/modules - they can be accessed via _parameters,
        _buffers, _modules keys in __dict__. For example, mod.linear can be
        accessed as mod.__dict__["_parameters"]["linear"]

        The most common and expensive case for nn module guards is of type
        mod.submod1.submod2.submod3.training. We avoid the python getattr of nn
        modules by going through the __dict__.

        Args:
            source: Attribute source; ``source.member`` is the attribute name and
                ``source.base`` the source of the module.
            base_guard_manager: Guard manager of the module the attribute is
                read from.
            base_example_value: The example module object.
            example_value: Example value of the attribute itself.
            base_source_name: Source string of the module (debug info).
            source_name: Source string of the attribute (debug info).
            guard_manager_enum: Manager type requested for the attribute when the
                plain getattr fallback is used.

        Returns:
            The guard manager for the attribute.
        """

        def getitem_on_dict_mgr(
            mgr, key, source_name, base_example_value, example_value, guard_manager_enum
        ):
            # Fetch the child manager for ``key`` from ``mgr``, using the
            # index-based key/value managers when ``mgr`` guards on key order.
            if isinstance(mgr, DictGuardManager):
                # Case where the user code relies on key order, e.g.,
                # named_parameters
                index = get_key_index(base_example_value, key)

                # Install the key manager and add equals match guard
                key_source = f"list({source_name}.keys())[{index!r}]"
                mgr.get_key_manager(
                    index=index,
                    source=key_source,
                    example_value=key,
                    guard_manager_enum=GuardManagerType.GUARD_MANAGER,
                ).add_equals_match_guard(key, [f"{key_source} == {key!r}"])

                # Install the value manager
                return mgr.get_value_manager(
                    index=index,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
            else:
                return mgr.dict_getitem_manager(
                    key=key,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )

        attr_name = source.member
        mod_dict = base_example_value.__dict__

        # Names defined anywhere in the class hierarchy of the module.
        all_class_attribute_names: Set[str] = set()
        for x in inspect.getmro(base_example_value.__class__):
            all_class_attribute_names.update(x.__dict__.keys())

        # Default: not reachable through __dict__.
        accessor_info = NNModuleAttrAccessorInfo(False, None, None)

        # Probe order: instance __dict__, then _parameters, _buffers, _modules.
        if attr_name in mod_dict:
            accessor_info = NNModuleAttrAccessorInfo(True, attr_name, None)
        elif "_parameters" in mod_dict and attr_name in mod_dict["_parameters"]:
            accessor_info = NNModuleAttrAccessorInfo(True, "_parameters", attr_name)
        elif "_buffers" in mod_dict and attr_name in mod_dict["_buffers"]:
            accessor_info = NNModuleAttrAccessorInfo(True, "_buffers", attr_name)
        elif (
            attr_name not in all_class_attribute_names
            and "_modules" in mod_dict
            and attr_name in mod_dict["_modules"]
        ):
            # Check test_attr_precedence test - instance attributes always take precedence unless its an nn.Module.
            accessor_info = NNModuleAttrAccessorInfo(True, "_modules", attr_name)

        if not accessor_info.present_in_generic_dict:
            # The attribute can be accessed by __getattribute__ call, so rely on
            # PyObject_GetAttr
            return base_guard_manager.getattr_manager(
                attr=source.member,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        else:
            assert accessor_info.l1_key
            l1_key = accessor_info.l1_key
            l2_key = accessor_info.l2_key

            # Set source strings for debug info
            mod_dict_source = f"{base_source_name}.__dict__"
            l1_source_name = l2_source_name = None
            l1_value = l2_value = None
            l1_guard_manager_enum = l2_guard_manager_enum = None
            if l2_key:
                # Two-level access: __dict__[l1_key][l2_key]
                l1_source = AttrSource(source.base, l1_key)
                l1_source_name = l1_source.name()
                l1_value = mod_dict[l1_key]
                # do not guard on key order for _parameters etc unless the user code
                # actually needs the key order (e.g. calling named_parameters)
                l1_guard_manager_enum = self.get_guard_manager_type(l1_source, l1_value)

                l2_source_name = source_name
                l2_value = example_value
                l2_guard_manager_enum = self.get_guard_manager_type(
                    source, example_value
                )
            else:
                # Single-level access: __dict__[l1_key]
                l1_source_name = source_name
                l1_value = example_value
                l1_guard_manager_enum = self.get_guard_manager_type(
                    source, example_value
                )

            # Get __dict__ accessor. No need to guard on dict key order, so use base
            # Guard Manager
            mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager(
                source=mod_dict_source,
                example_value=mod_dict,
                guard_manager_enum=GuardManagerType.GUARD_MANAGER,
            )

            l1_mgr = getitem_on_dict_mgr(
                mgr=mod_generic_dict_manager,
                key=l1_key,
                source_name=l1_source_name,
                base_example_value=mod_dict,
                example_value=l1_value,
                guard_manager_enum=l1_guard_manager_enum,
            )

            if l2_key:
                return getitem_on_dict_mgr(
                    mgr=l1_mgr,
                    key=l2_key,
                    source_name=l2_source_name,
                    base_example_value=l1_value,
                    example_value=l2_value,
                    guard_manager_enum=l2_guard_manager_enum,
                )
            return l1_mgr
787
788    def requires_key_order_guarding(self, source):
789        source_name = source.name()
790        if source_name == "":
791            return False
792        obj_id = id(self.get(source_name))
793        return obj_id in self.key_order_guarded_dict_ids
794
795    def get_guard_manager_type(self, source, example_value):
796        guard_manager_enum = GuardManagerType.GUARD_MANAGER
797        if self.requires_key_order_guarding(source):
798            assert isinstance(example_value, dict)
799            # If keys method is not overriden, we can use PyDict_Next to get key
800            # orderings. Read more in guards.cpp
801            if type(example_value).keys is type({}).keys:
802                guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER
803            else:
804                guard_manager_enum = GuardManagerType.DICT_SUBCLASS_GUARD_MANAGER
805        return guard_manager_enum
806
807    def manager_guards_on_keys(self, mgr_enum):
808        return (
809            mgr_enum == GuardManagerType.DICT_GUARD_MANAGER
810            or mgr_enum == GuardManagerType.DICT_SUBCLASS_GUARD_MANAGER
811        )
812
813    def get_global_guard_manager(self):
814        assert self.guard_manager  # to make mypy happy
815        return self.guard_manager.root.globals_dict_manager(
816            f_globals=self.scope["G"],
817            source="G",
818            example_value=self.scope["G"],
819            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
820        )
821
    def get_guard_manager_from_source(self, source):
        """Return the guard manager corresponding to ``source``.

        Managers are memoized in ``self._cached_guard_managers`` keyed by the
        source name. On a cache miss, the manager is built by dispatching on
        the type of ``source``; for a ``ChainedSource`` the manager of
        ``source.base`` is obtained first (recursively) and the new manager is
        derived from it.

        ``GlobalStateSource`` and ``ShapeEnvSource`` return the root guard
        manager directly and are not cached.

        Raises:
            RuntimeError: a ``GetItemSource`` on a dict uses a
                ``ConstDictKeySource`` index but the base manager is not a
                ``DictGuardManager``.
            AssertionError: the source type has no manager builder, or a
                ``ConstDictKeySource`` is used on a non-``DictGuardManager``.
        """
        assert self.guard_manager  # to make mypy happy
        root_guard_manager = self.guard_manager.root

        example_value = None
        source_name = source.name()

        # Fast path: a manager was already built for this (named) source.
        if source_name != "" and source_name in self._cached_guard_managers:
            return self._cached_guard_managers[source_name]

        # Sources with an empty name have no value to look up.
        if source_name != "":
            example_value = self.get(source_name)

        guard_manager_enum = self.get_guard_manager_type(source, example_value)

        # Get base manager related information
        base_source_name = None
        base_example_value = None
        base_guard_manager = None
        base_guard_manager_enum = GuardManagerType.GUARD_MANAGER
        if isinstance(source, ChainedSource):
            base_source_name = source.base.name()
            base_example_value = self.get(base_source_name)
            base_guard_manager = self.get_guard_manager_from_source(source.base)
            base_guard_manager_enum = self.get_guard_manager_type(
                source.base, base_example_value
            )

        # Use istype instead of isinstance to check for exact type of source.
        if istype(source, LocalSource):
            # RootGuardManager accepts a dict but still it's not a
            # DictGuardManager because we will eventually move to
            # fastlocals.
            out = root_guard_manager.dict_getitem_manager(
                key=source.local_name,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GlobalSource):
            # Global manager accepts a dict but it is not a DictGuardManager
            # because globals dict is big and we typically guard on a very
            # selected items on globals.
            out = self.get_global_guard_manager().dict_getitem_manager(
                key=source.global_name,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GlobalWeakRefSource):
            out = self.get_global_guard_manager().global_weakref_manager(
                global_name=source.global_name,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GlobalStateSource):
            # Don't do anything here. We guard on global state completely in
            # C++. So just return the root mgr.
            return root_guard_manager
        elif istype(source, ShapeEnvSource):
            # Returned directly (and, like GlobalStateSource, not cached).
            return root_guard_manager
        elif istype(source, TypeSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.type_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(
            source,
            (
                OptimizerSource,
                NNModuleSource,
                UnspecializedNNModuleSource,
                UnspecializedBuiltinNNModuleSource,
                FSDPNNModuleSource,
            ),
        ):
            # These wrapper sources reuse the manager of their base as-is.
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager
        elif istype(source, GradSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.grad_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, (AttrSource, UnspecializedParamBufferSource)):
            assert base_guard_manager  # to make mypy happy

            # nn.Module instances that still use the unpatched __getattr__ go
            # through the dedicated nn.Module attribute path.
            if (
                isinstance(base_example_value, torch.nn.Module)
                and get_custom_getattr(base_example_value)
                is unpatched_nn_module_getattr
            ):
                out = self.getattr_on_nn_module(
                    source,
                    base_guard_manager,
                    base_example_value,
                    example_value,
                    base_source_name,
                    source_name,
                    guard_manager_enum,
                )
            else:
                out = base_guard_manager.getattr_manager(
                    attr=source.member,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
        elif istype(source, GetItemSource):
            assert base_guard_manager  # to make mypy happy
            if isinstance(base_example_value, (dict, collections.OrderedDict)):
                # TODO(anijain2305) - Consider isolating GetItemSource and
                # DictGetItemSource (or maybe use ODictGetItemSource for
                # dicts) so that GetItemSource is only for non dict objects.
                if isinstance(base_guard_manager, DictGuardManager):
                    assert self.manager_guards_on_keys(base_guard_manager_enum)
                    out = getitem_on_dict_manager(
                        source,
                        base_guard_manager,
                        base_example_value,
                        example_value,
                        guard_manager_enum,
                    )
                else:
                    if isinstance(source.index, ConstDictKeySource):
                        raise RuntimeError(
                            "Expecting clean index here. Likely Dynamo forgot to mark"
                            " a dict as guard_on_key_order"
                        )
                    out = base_guard_manager.dict_getitem_manager(
                        key=source.index,
                        source=source_name,
                        example_value=example_value,
                        guard_manager_enum=guard_manager_enum,
                    )
            elif isinstance(base_example_value, list) and not source.index_is_slice:
                out = base_guard_manager.list_getitem_manager(
                    key=source.index,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
            elif isinstance(base_example_value, tuple) and not source.index_is_slice:
                out = base_guard_manager.tuple_getitem_manager(
                    key=source.index,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
            else:
                # Generic __getitem__ path; slice indices are unpacked first.
                index = source.index
                if source.index_is_slice:
                    index = source.unpack_slice()
                out = base_guard_manager.getitem_manager(
                    key=index,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
        elif istype(source, ODictGetItemSource):
            if isinstance(base_guard_manager, DictGuardManager):
                assert self.manager_guards_on_keys(base_guard_manager_enum)
                out = getitem_on_dict_manager(
                    source,
                    base_guard_manager,
                    base_example_value,
                    example_value,
                    guard_manager_enum,
                )
            else:
                assert base_guard_manager  # to make mypy happy
                out = base_guard_manager.dict_getitem_manager(
                    key=source.index,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
        elif istype(source, DefaultsSource):
            assert base_guard_manager  # to make mypy happy
            assert callable(base_example_value)
            if not source.is_kw:
                # Positional defaults: index into the function's __defaults__.
                out = base_guard_manager.func_defaults_manager(
                    source=base_source_name,
                    example_value=base_example_value.__defaults__,
                    guard_manager_enum=GuardManagerType.GUARD_MANAGER,
                ).getitem_manager(
                    key=source.idx_key,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
            else:
                # kwdefaults is a dict, looked up by key via a plain guard
                # manager (see the assert below).
                kwdefaults = base_example_value.__kwdefaults__
                assert base_source_name is not None
                kw_source = base_source_name + ".__kwdefaults__"

                # kwdefaults is a dict. No need to guard on dict order.
                dict_mgr = base_guard_manager.func_kwdefaults_manager(
                    source=kw_source,
                    example_value=kwdefaults,
                    guard_manager_enum=GuardManagerType.GUARD_MANAGER,
                )
                assert not isinstance(dict_mgr, DictGuardManager)

                out = dict_mgr.dict_getitem_manager(
                    key=source.idx_key,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
        elif istype(source, NumpyTensorSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=from_numpy,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, SubclassAttrListSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: x.__tensor_flatten__()[0],
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, FlattenScriptObjectSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: x.__obj_flatten__(),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, ScriptObjectQualifiedNameSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: x._type().qualified_name(),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, AttrProxySource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: x.get_base(),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, TupleIteratorGetItemSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.tuple_iterator_getitem_manager(
                index=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        # NOTE: the two branches below use isinstance (subclasses included),
        # unlike the exact-type istype checks above.
        elif isinstance(source, ConstDictKeySource):
            if not isinstance(base_guard_manager, DictGuardManager):
                raise AssertionError(
                    "ConstDictKeySource can only work on DictGuardManager"
                )
            out = base_guard_manager.get_key_manager(
                index=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif isinstance(source, WeakRefCallSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.weakref_call_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        else:
            raise AssertionError(
                f"missing guard manager builder {source} - {source.name()}"
            )

        # Memoize so later guards on the same source reuse this manager.
        self._cached_guard_managers[source.name()] = out
        return out
1110
1111    def get_guard_manager(self, guard: Guard):
1112        return self.get_guard_manager_from_source(guard.originating_source)
1113
1114    def add_python_lambda_leaf_guard_to_root(
1115        self,
1116        code_parts,
1117        verbose_code_parts,
1118        closure_vars=CLOSURE_VARS,
1119        is_epilogue=True,
1120    ):
1121        # Adds a lambda leaf guard to the root guard manager. It wraps the
1122        # code_parts in a function object which is then passed on to the leaf
1123        # guard.
1124        make_guard_fn_args = ", ".join(closure_vars.keys())
1125        guard_body, pycode = build_guard_function(code_parts, make_guard_fn_args)
1126        out: Dict[str, Any] = {}
1127        globals_for_guard_fn = {"G": self.scope["G"]}
1128        exec(pycode, globals_for_guard_fn, out)
1129        guard_fn = out["___make_guard_fn"](*closure_vars.values())
1130        assert self.guard_manager  # to make mypy happy
1131        if is_epilogue:
1132            # Epilogue guards are run after all the other guards have finished.
1133            # If epilogue guards contain a getattr or getitem access, one of the
1134            # other guards would fail preventing the epilogue guards to run.
1135            self.guard_manager.root.add_epilogue_lambda_guard(
1136                guard_fn, verbose_code_parts
1137            )
1138        else:
1139            self.guard_manager.root.add_lambda_guard(guard_fn, verbose_code_parts)
1140
1141    # Warning: use this with care!  This lets you access what the current
1142    # value of the value you are guarding on is.  You probably don't want
1143    # to actually durably save this value though (because it's specific
1144    # to this frame!)  Instead, you should be reading out some property
1145    # (like its type) which is what you permanently install into the
1146    # guard code.
1147    def get(self, name: str) -> Any:
1148        return eval(name, self.scope, CLOSURE_VARS)
1149
1150    # Registers the usage of the source name referenced by the
1151    # string (or stored in the Guard) as being guarded upon.  It's important
1152    # to call this before generating some code that makes use of 'guard',
1153    # because without this call, we won't actually bind the variable
1154    # you reference in the actual guard closure (oops!)
1155    def arg_ref(self, guard: Union[str, Guard]) -> str:
1156        name: str
1157        if isinstance(guard, str):
1158            name = guard
1159        else:
1160            name = guard.name
1161        base = strip_getattr_getitem(strip_function_call(name))
1162        if base not in self.argnames:
1163            if re.match(r"[a-zA-Z0-9_]+", base):
1164                if re.match(r"^\d+$", base):
1165                    log.warning("invalid var name: %s", guard)
1166                self.argnames.append(base)
1167
1168        return name
1169
1170    def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn):
1171        attr_source = AttrSource(guard.originating_source, attr_name)
1172        # Copy the stack info
1173        new_guard = Guard(
1174            attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack
1175        )
1176        new_guard.create(self)
1177
1178    # Note: the order of the guards in this file matters since we sort guards on the same object by lineno
    def HASATTR(self, guard: Guard):
        """Guard on whether the base object has (or lacks) an attribute.

        The guard's source must be an ``AttrSource`` (an ``NNModuleSource``
        wrapper is unwrapped first). The presence of the attribute on the
        current value of the base is recorded, and a matching
        ``hasattr`` / ``not hasattr`` guard is installed.
        """
        source = guard.originating_source
        if isinstance(source, NNModuleSource):
            source = source.base
        assert isinstance(source, AttrSource), f"invalid source {guard.name}"
        base_source = source.base
        base = base_source.name()
        attr = source.member

        ref = self.arg_ref(base)
        # Whether the attribute is present right now decides the guard's polarity.
        val = hasattr(self.get(base), attr)
        code = None
        if val:
            code = f"hasattr({ref}, {attr!r})"
        else:
            code = f"not hasattr({ref}, {attr!r})"
        self._set_guard_export_info(
            guard, [code], provided_guarded_object=self.get(base)
        )

        if config.enable_cpp_guard_manager:
            base_manager = self.get_guard_manager_from_source(base_source)
            if val:
                # Just install a getattr manager. GetAttrGuardAccessor itself
                # acts as hasattr guard.
                example_value = self.get(source.name())
                base_example_value = self.get(base)
                guard_manager_enum = self.get_guard_manager_type(source, example_value)

                # if the base value is nn.Module, check if we can speedup the
                # guard by going through __dict__ attrs.
                if (
                    isinstance(base_example_value, torch.nn.Module)
                    and get_custom_getattr(base_example_value)
                    is unpatched_nn_module_getattr
                ):
                    # NOTE(review): this is the only branch that returns a
                    # value; the other branches implicitly return None.
                    return self.getattr_on_nn_module(
                        source,
                        base_manager,
                        base_example_value,
                        example_value,
                        base,
                        source.name(),
                        guard_manager_enum,
                    )
                else:
                    base_manager.getattr_manager(
                        attr=attr,
                        source=guard.name,
                        example_value=example_value,
                        guard_manager_enum=guard_manager_enum,
                    )
            else:
                base_manager.add_no_hasattr_guard(
                    attr, get_verbose_code_parts(code, guard)
                )
        else:
            self._produce_guard_code(guard, [code])
1237
1238    def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None:
1239        assert attr is not None
1240        ref = self.arg_ref(guard)
1241        val = self.get(guard.name)
1242        assert isinstance(val, torch.nn.Module)
1243
1244        base_manager = self.get_guard_manager(guard)
1245
1246        mod_dict_source = f"{guard.name}.__dict__"
1247        mod_generic_dict_manager = base_manager.get_generic_dict_manager(
1248            source=mod_dict_source,
1249            example_value=val.__dict__,
1250            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
1251        )
1252
1253        code = f"not ___dict_contains({attr!r}, {ref}.__dict__)"
1254        mod_generic_dict_manager.add_dict_contains_guard(
1255            False, attr, get_verbose_code_parts(code, guard)
1256        )
1257
1258    def TYPE_MATCH(self, guard: Guard) -> None:
1259        # ___check_type_id is same as `id(type(x)) == y`
1260        t = type(self.get(guard.name))
1261        obj_id = self.id_ref(t)
1262        code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
1263        self._set_guard_export_info(guard, [code])
1264
1265        if config.enable_cpp_guard_manager:
1266            self.get_guard_manager(guard).add_type_match_guard(
1267                obj_id, get_verbose_code_parts(code, guard)
1268            )
1269        else:
1270            self._produce_guard_code(guard, [code])
1271
1272    def DICT_VERSION(self, guard: Guard):
1273        # ___check_dict_version is same as `dict_version(x) == y`
1274        ref = self.arg_ref(guard)
1275        val = self.get(guard.name)
1276        version = dict_version(self.get(guard.name))
1277        code = f"___dict_version({ref}) == {version}"
1278        self._set_guard_export_info(guard, [code])
1279
1280        if config.enable_cpp_guard_manager:
1281            # TODO(anijain2305) - Delete this when DictGuardManager uses tags
1282            # for dicts.
1283            self.get_guard_manager(guard).add_dict_version_guard(
1284                val, get_verbose_code_parts(code, guard)
1285            )
1286        else:
1287            self._produce_guard_code(guard, [code])
1288
1289    def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool):
1290        dict_ref = self.arg_ref(guard)
1291
1292        maybe_not = "not " if invert else ""
1293        code = f"{maybe_not}___dict_contains({key!r}, {dict_ref})"
1294        self._set_guard_export_info(guard, [code])
1295
1296        if config.enable_cpp_guard_manager:
1297            self.get_guard_manager(guard).add_dict_contains_guard(
1298                not invert, key, get_verbose_code_parts(code, guard)
1299            )
1300        else:
1301            self._produce_guard_code(guard, [code])
1302
1303    def ID_MATCH(self, guard: Guard):
1304        # ___check_obj_id is same as `id(x) == y`
1305        if isinstance(guard.originating_source, TypeSource):
1306            # optional optimization to produce cleaner/faster guard code
1307            return self.TYPE_MATCH(
1308                Guard(guard.originating_source.base, GuardBuilder.TYPE_MATCH)  # type: ignore[arg-type]
1309            )
1310
1311        ref = self.arg_ref(guard)
1312        val = self.get(guard.name)
1313        id_val = self.id_ref(val)
1314        code = f"___check_obj_id({ref}, {id_val})"
1315        self._set_guard_export_info(guard, [code])
1316
1317        if config.enable_cpp_guard_manager:
1318            self.get_guard_manager(guard).add_id_match_guard(
1319                id_val, get_verbose_code_parts(code, guard)
1320            )
1321        else:
1322            self._produce_guard_code(guard, [code])
1323
1324        # Keep track of ID_MATCH'd objects. This will be used to modify the
1325        # cache size logic
1326        if isinstance(guard.originating_source, LocalSource):
1327            # TODO(anijain2305) - This is currently restricted to nn.Module objects
1328            # because many other ID_MATCH'd objects fail - like DeviceMesh.
1329            # Increase the scope of ID_MATCH'd objects.
1330            if isinstance(val, torch.nn.Module):
1331                local_name = guard.originating_source.local_name
1332                weak_id = self.lookup_weakrefs(val)
1333                if weak_id is not None:
1334                    self.id_matched_objs[local_name] = weak_id
1335
1336    def NOT_NONE_MATCH(self, guard: Guard, value=None):
1337        ref = self.arg_ref(guard)
1338        val = self.get(guard.name)
1339        assert isinstance(val, torch.Tensor)
1340        code = f"{ref} is not None"
1341        self._set_guard_export_info(guard, [code])
1342
1343        if config.enable_cpp_guard_manager:
1344            self.get_guard_manager(guard).add_not_none_guard(
1345                get_verbose_code_parts(code, guard)
1346            )
1347        else:
1348            self._produce_guard_code(guard, [code])
1349
1350    def NAME_MATCH(self, guard: Guard):
1351        self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH)
1352
1353    def DATA_PTR_MATCH(self, guard: Guard):
1354        # Add a type check. C++ guard has the type check internally, so only
1355        # enable it for Python guards.
1356        if not config.enable_cpp_guard_manager:
1357            self.TYPE_MATCH(guard)
1358
1359        obj = self.get(guard.name)
1360        code = f"{self.arg_ref(guard)}.data_ptr() == {obj.data_ptr()}"
1361        self._set_guard_export_info(guard, [code])
1362
1363        if config.enable_cpp_guard_manager:
1364            self.get_guard_manager(guard).add_data_ptr_guard(
1365                obj, get_verbose_code_parts(code, guard)
1366            )
1367        else:
1368            self._produce_guard_code(guard, [code])
1369
1370    def DUAL_LEVEL(self, guard: Guard):
1371        # Invalidate dual level if current dual level is different than the one
1372        # in the fx graph
1373        dual_level = torch.autograd.forward_ad._current_level
1374        code = [f"torch.autograd.forward_ad._current_level == {dual_level}"]
1375        self._set_guard_export_info(guard, [code])
1376        if config.enable_cpp_guard_manager:
1377            # TODO(anijain2305) - Consider this moving this guard to C++
1378            forward_ad = torch.autograd.forward_ad
1379
1380            def fn(x):
1381                return forward_ad._current_level == dual_level
1382
1383            assert self.guard_manager  # to make mypy happy
1384            self.guard_manager.root.add_lambda_guard(
1385                fn, get_verbose_code_parts(code, guard)
1386            )
1387        else:
1388            self._produce_guard_code(guard, code)
1389
1390    def FUNCTORCH_STACK_MATCH(self, guard: Guard):
1391        # Invalidate functorch code if current level is different than
1392        # the one when FX graph was generated
1393        cis = torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters()
1394        states = [ci.get_state() for ci in cis]
1395        code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"]
1396        self._set_guard_export_info(guard, code)
1397
1398        if config.enable_cpp_guard_manager:
1399            # TODO(anijain2305) - Consider this moving this guard to C++
1400            compare_fn = torch._functorch.pyfunctorch.compare_functorch_state
1401
1402            def fn(x):
1403                return compare_fn(states)
1404
1405            assert self.guard_manager  # to make mypy happy
1406            self.guard_manager.root.add_lambda_guard(
1407                fn, get_verbose_code_parts(code, guard)
1408            )
1409        else:
1410            self._produce_guard_code(guard, code)
1411
1412    def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard):
1413        value = self.get(guard.name)
1414        original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1])
1415        if hasattr(value, "__metadata_guard__"):
1416            verify_guard_fn_signature(value)
1417
1418            def metadata_checker(x):
1419                return value.__metadata_guard__(
1420                    original_metadata, x.__tensor_flatten__()[1]
1421                )
1422
1423        else:
1424
1425            def metadata_checker(x):
1426                return x.__tensor_flatten__()[1] == original_metadata
1427
1428        global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}"
1429        if config.enable_cpp_guard_manager:
1430            self.get_guard_manager(guard).add_lambda_guard(
1431                metadata_checker, get_verbose_code_parts(global_name, guard)
1432            )
1433        else:
1434            global_scope = self.get("G")
1435            global_scope[global_name] = metadata_checker
1436            code = [f"{global_name}({self.get(guard.name)})"]
1437            self._produce_guard_code(guard, code)
1438
    def EQUALS_MATCH(self, guard: Guard):
        """Guard that the guarded value compares equal to its current value.

        The current value must be of one of the allowed types (or a plain
        dict whose keys and values are all of allowed types). NaN floats and
        NaN complex numbers are handled separately through a type match plus
        an ``isnan`` check. With the C++ guard manager an equals-match guard
        is installed; otherwise Python guard code combining type checks and an
        ``==`` comparison against the value's repr is produced.
        """
        ref = self.arg_ref(guard)
        val = self.get(guard.name)
        t = type(val)
        # numpy scalar types are only allowed when numpy is available.
        if np:
            np_types: Tuple[Type[Any], ...] = (
                np.int8,
                np.int16,
                np.int32,
                np.int64,
                np.uint8,
                np.uint16,
                np.uint32,
                np.uint64,
                np.float16,
                np.float32,
                np.float64,
            )
        else:
            np_types = ()

        ok_mutable_types = (list, set)

        ok_types = tuple(
            common_constant_types
            | {
                type,
                tuple,
                frozenset,
                slice,
                range,
                torch.Size,
                *np_types,
                *ok_mutable_types,
            }
        )

        # Distributed placement/mesh types are allowed only when
        # torch.distributed is available.
        if torch.distributed.is_available():
            from torch.distributed.device_mesh import DeviceMesh
            from torch.distributed.tensor.placement_types import (
                Partial,
                Replicate,
                Shard,
            )

            ok_types = ok_types + (
                Shard,
                Replicate,
                Partial,
                DeviceMesh,
            )

        if istype(val, dict):
            assert all(
                istype(x, ok_types) for x in itertools.chain(val.keys(), val.values())
            )
        else:
            assert istype(
                val,
                ok_types,
            ), f"Unexpected type {type(val)}, not in {ok_types}"

        # Special case for nan because float("nan") == float("nan") evaluates to False
        if istype(val, float) and math.isnan(val):
            self.TYPE_MATCH(guard)
            code = []
            code.append(f"__math_isnan({ref})")
            self._set_guard_export_info(guard, code)

            if config.enable_cpp_guard_manager:
                self.get_guard_manager(guard).add_lambda_guard(
                    CLOSURE_VARS["__math_isnan"], get_verbose_code_parts(code, guard)
                )
            else:
                self._produce_guard_code(guard, code)
            return

        # Python math library doesn't support complex nan, so we need to use numpy
        # NOTE(review): `np` may be falsy (see the `if np:` check above); in
        # that case `np.isnan` would fail for complex values — confirm.
        if istype(val, complex) and np.isnan(val):
            self.TYPE_MATCH(guard)
            code = []
            code.append(f"__numpy_isnan({ref})")
            self._set_guard_export_info(guard, code)

            if config.enable_cpp_guard_manager:
                self.get_guard_manager(guard).add_lambda_guard(
                    CLOSURE_VARS["__numpy_isnan"], get_verbose_code_parts(code, guard)
                )
            else:
                self._produce_guard_code(guard, code)
            return

        if config.enable_cpp_guard_manager:
            # Construct a debug string to put into the c++ equals match guard.
            code = [f"{ref} == {val!r}"]
            if istype(val, ok_mutable_types):
                # C++ guards perform a pointer equality check to speedup guards, but the assumption is that the object
                # is mutable. For a few corner cases like sets and lists, we make a deepcopy to purposefully fail the
                # pointer equality check.
                # NOTE(review): "mutable" above presumably should read
                # "immutable" — the deepcopy is applied only to the mutable
                # types (list, set).
                val = deepcopy(val)
            self.get_guard_manager(guard).add_equals_match_guard(
                val, get_verbose_code_parts(code, guard)
            )
            self._set_guard_export_info(guard, code)
            return

        # Python guard path from here on.
        code = []

        # If matching equality against list/tuple, we must also check that
        # the internal types match.  (TODO: what about nested lists?)
        if istype(val, (list, tuple)):
            # NB: SEQUENCE_LENGTH takes care of the outer __check_type_id test
            self.SEQUENCE_LENGTH(guard)

            for idx, elem in enumerate(val):
                code.append(
                    f"___check_type_id({ref}[{idx}], {self.id_ref(type(elem))})"
                )
        else:
            # Add type check to prevent equality check between tensor and non-tensor.
            self.TYPE_MATCH(guard)

        # torch.Size is compared through its plain-tuple form.
        if istype(val, torch.Size):
            val = tuple(val)

        # Code object can not be compared against their string representation
        # I.e `eval(f"{compile('2+2','','exec')!r}")` raises SyntaxError
        assert not istype(val, types.CodeType)

        # TODO: It feels like it would be better to just implement our own
        # equality test in C that handles all of the necessary type checking
        # and NaN tests
        code.append(f"{ref} == {val!r}")
        self._produce_guard_code(guard, code)
        self._set_guard_export_info(guard, code)
1574
1575    def CONSTANT_MATCH(self, guard: Guard):
1576        val = self.get(guard.name)
1577        if istype(val, (bool, type(None), types.CodeType)):
1578            self.ID_MATCH(guard)
1579        else:
1580            self.EQUALS_MATCH(guard)
1581
1582    def NN_MODULE(self, guard: Guard):
1583        self.ID_MATCH(guard)
1584        val = self.get(guard.name)
1585        if hasattr(val, "training"):
1586            assert istype(val.training, bool)
1587            self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH)
1588        else:
1589            exc.unimplemented(f"Guard setup for uninitialized class {type(val)}")
1590
1591    def FUNCTION_MATCH(self, guard: Guard):
1592        """things like torch.add and user defined functions"""
1593        return self.ID_MATCH(guard)
1594
1595    def CLOSURE_MATCH(self, guard: Guard):
1596        """matches a closure by __code__ id."""
1597        val = self.get(guard.name)
1598        # Strictly only want user-defined functions
1599        if type(val) == types.FunctionType and hasattr(val, "__code__"):
1600            self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR)
1601            self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH)
1602        else:
1603            self.FUNCTION_MATCH(guard)
1604
1605    def BUILTIN_MATCH(self, guard: Guard):
1606        return self.FUNCTION_MATCH(guard)
1607
1608    def PYMODULE_MATCH(self, guard: Guard):
1609        return self.FUNCTION_MATCH(guard)
1610
1611    def SEQUENCE_LENGTH(self, guard):
1612        # This guard is used to check lenght of PySequence objects like list,
1613        # tuple, collections.deque etc
1614        ref = self.arg_ref(guard)
1615        value = self.get(guard.name)
1616        t = type(value)
1617
1618        if not (config.enable_cpp_guard_manager and isinstance(value, dict)):
1619            # C++ DICT_LENGTH checks for type
1620            self.TYPE_MATCH(guard)
1621
1622        code = []
1623        if len(value) == 0:
1624            code.append(f"not {ref}")
1625        else:
1626            code.append(f"len({ref}) == {len(value)}")
1627
1628        self._set_guard_export_info(guard, code)
1629        if config.enable_cpp_guard_manager:
1630            if isinstance(value, dict):
1631                self.get_guard_manager(guard).add_dict_length_check_guard(
1632                    len(value), get_verbose_code_parts(code, guard)
1633                )
1634            else:
1635                self.get_guard_manager(guard).add_length_check_guard(
1636                    len(value), get_verbose_code_parts(code, guard)
1637                )
1638        else:
1639            self._produce_guard_code(guard, code)
1640
1641    def TUPLE_ITERATOR_LEN(self, guard):
1642        ref = self.arg_ref(guard)
1643        value = self.get(guard.name)
1644        t = type(value)
1645
1646        if not config.enable_cpp_guard_manager:
1647            # C++ guard already checks the type
1648            self.TYPE_MATCH(guard)
1649
1650        code = []
1651        code.append(f"___tuple_iterator_len({ref}) == {tuple_iterator_len(value)}")
1652        self._set_guard_export_info(guard, code)
1653
1654        if config.enable_cpp_guard_manager:
1655            t = type(value)
1656            obj_id = self.id_ref(t)
1657
1658            self.get_guard_manager(guard).add_tuple_iterator_length_guard(
1659                tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard)
1660            )
1661        else:
1662            self._produce_guard_code(guard, code)
1663
1664    # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards
1665    def DUPLICATE_INPUT(self, guard, source_b):
1666        ref_a = self.arg_ref(guard)
1667        ref_b = self.arg_ref(source_b.name())
1668
1669        if is_from_optimizer_source(
1670            guard.originating_source
1671        ) or is_from_optimizer_source(source_b):
1672            return
1673
1674        code = [f"{ref_b} is {ref_a}"]
1675        self._set_guard_export_info(guard, code)
1676
1677        if config.enable_cpp_guard_manager:
1678            # Check that the guard has not been inserted already
1679            key = (ref_a, ref_b)
1680            if key in self._cached_duplicate_input_guards:
1681                return
1682            self._cached_duplicate_input_guards.add((ref_a, ref_b))
1683            self._cached_duplicate_input_guards.add((ref_b, ref_a))
1684
1685            install_object_aliasing_guard(
1686                self.get_guard_manager(guard),
1687                self.get_guard_manager_from_source(source_b),
1688                get_verbose_code_parts(code, guard),
1689            )
1690        else:
1691            self._produce_guard_code(guard, code)
1692
1693    def DICT_KEYS(self, guard):
1694        # Guard on the keys and their order
1695        ref = self.arg_ref(guard)
1696        value = self.get(guard.name)
1697        t = type(value)
1698
1699        self.TYPE_MATCH(guard)
1700        code = []
1701        any_key_is_id = any(key_is_id(k) for k in value.keys())
1702        const_keys_repr = dict_keys_repr(
1703            key_to_id(value),
1704            local=is_from_local_source(guard.originating_source),
1705        )
1706        if any_key_is_id:
1707            code.append(f"___key_to_id({ref}) == {const_keys_repr}")
1708        else:
1709            code.append(f"list({ref}.keys()) == {const_keys_repr}")
1710
1711        self._set_guard_export_info(guard, code)
1712        if config.enable_cpp_guard_manager:
1713            if self.requires_key_order_guarding(guard.originating_source):
1714                self.guard_on_dict_keys_and_order(value, guard)
1715            else:
1716                self.guard_on_dict_keys_and_ignore_order(value, guard)
1717        else:
1718            self._produce_guard_code(guard, code)
1719
1720    def WEAKREF_ALIVE(self, guard):
1721        code = [f"{self.arg_ref(guard)} is not None"]
1722
1723        self._set_guard_export_info(guard, code)
1724        if config.enable_cpp_guard_manager:
1725            self.get_guard_manager(guard).add_not_none_guard(
1726                get_verbose_code_parts(code, guard)
1727            )
1728        else:
1729            self._produce_guard_code(guard, code)
1730
1731    def DICT_CONST_KEYS(self, guard):
1732        """Constant keys match"""
1733        ref = self.arg_ref(guard)
1734        value = self.get(guard.name)
1735        t = type(value)
1736
1737        if not config.enable_cpp_guard_manager:
1738            # DictGuardManager supports TYPE_MATCH internally
1739            self.TYPE_MATCH(guard)
1740
1741        code = []
1742        code.append(f"list({ref}.keys()) == {list(value.keys())!r}")
1743        self._set_guard_export_info(guard, code)
1744
1745        if config.enable_cpp_guard_manager:
1746            if self.requires_key_order_guarding(guard.originating_source):
1747                self.guard_on_dict_keys_and_order(value, guard)
1748            else:
1749                self.guard_on_dict_keys_and_ignore_order(value, guard)
1750        else:
1751            self._produce_guard_code(guard, code)
1752
1753    def EMPTY_NN_MODULE_HOOKS_DICT(self, guard):
1754        """Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards"""
1755        if config.skip_nnmodule_hook_guards:
1756            # This is unsafe if you add/remove a hook on nn module variable
1757            return
1758        self.SEQUENCE_LENGTH(guard)
1759
1760    def OBJECT_MUTATION(self, guard: Guard):
1761        mutation_guard.watch(self.get(guard.name), self.check_fn_manager)
1762
1763    def GRAD_MODE(self, guard: Guard):
1764        pass  # we always guard on this via GlobalStateGuard()
1765
1766    def DETERMINISTIC_ALGORITHMS(self, guard: Guard):
1767        pass  # we always guard on this via GlobalStateGuard()
1768
1769    def TORCH_FUNCTION_STATE(self, guard: Guard):
1770        pass  # we always guard on this via GlobalStateGuard()
1771
1772    def FSDP_TRAINING_STATE(self, guard: Guard):
1773        pass  # we always guard on this via GlobalStateGuard()
1774
1775    def DEFAULT_DEVICE(self, guard: Guard):
1776        """Guard on CURRENT_DEVICE per torch.utils._device"""
1777        assert guard.source is GuardSource.GLOBAL
1778        import torch.utils._device as m
1779
1780        code = [f"utils_device.CURRENT_DEVICE == {m.CURRENT_DEVICE!r}"]
1781        self._set_guard_export_info(guard, code)
1782
1783        if config.enable_cpp_guard_manager:
1784            self.get_guard_manager(guard).add_default_device_guard(
1785                get_verbose_code_parts(code, guard)
1786            )
1787        else:
1788            self._produce_guard_code(guard, code)
1789
    def SHAPE_ENV(self, guard: Guard):
        """Produce and install the symbolic-shape guards for the traced frame.

        The guard must be nameless (``guard.name == ""``). Guard expressions
        are obtained from ``output_graph.shape_env.produce_guards`` using the
        fakes and sources recorded in ``output_graph.tracked_fakes``; when the
        output graph carries export constraints, they are first turned into an
        ``EqualityConstraint`` that is passed along.

        Outside of export the shape env is frozen afterwards. The resulting
        expressions are installed either as a single Python lambda guard on
        the root guard manager (C++ guard manager enabled) or one by one as
        shape-env guard code.
        """
        # Let's handle ShapeEnv guards.  To do this, we will resolve
        # shape variables to sources from tracked_fakes.  This must happen after
        # tensor checks.
        assert guard.name == ""
        output_graph = self.check_fn_manager.output_graph
        # NB: self.output_graph can be None in the debug_nops tests
        # NOTE(review): the next line dereferences output_graph unconditionally,
        # so a None value would raise here - confirm whether None can still
        # reach this method.
        fs = output_graph.tracked_fakes
        input_contexts = [a.symbolic_context for a in fs]

        def get_sources(t_id, dim):
            # Looks up base sources mapped to a tensor id and uses them to create
            # sources for the corresponding tensor dimension.
            return [
                TensorPropertySource(source, TensorProperty.SIZE, dim)
                for source in output_graph.tracked_fakes_id_to_source[t_id]
            ]

        if output_graph.export_constraints:
            # Accumulators filled in place by _process_equalities below.
            names: Dict[str, Tuple[int, int]] = {}
            source_pairs: List[Tuple[Source, Source]] = []
            derived_equalities: List[  # type: ignore[type-arg]
                Tuple[Source, Union[Source, Symbol], Callable]
            ] = []
            phantom_symbols: Dict[str, Symbol] = {}
            for constraint in output_graph.export_constraints:
                # Only constraints on tensors with a known source can be
                # processed; others are reported and skipped.
                if constraint.t_id in output_graph.tracked_fakes_id_to_source:
                    torch.export.dynamic_shapes._process_equalities(
                        constraint,
                        get_sources,
                        output_graph.shape_env,
                        names,
                        source_pairs,
                        derived_equalities,
                        phantom_symbols,
                    )
                else:
                    log.warning("Untracked tensor used in export constraints")
            equalities_inputs = EqualityConstraint(
                source_pairs=source_pairs,
                derived_equalities=derived_equalities,
                phantom_symbols=list(phantom_symbols.values()),
                warn_only=False,
            )
        else:
            equalities_inputs = None
        guards = output_graph.shape_env.produce_guards(
            [a.fake for a in fs],
            [a.source for a in fs],
            input_contexts=input_contexts,
            equalities_inputs=equalities_inputs,
            source_ref=self.source_ref,
            # Export keeps static.
            ignore_static=(not self.check_fn_manager.output_graph.export),
        )
        # When exporting, we may work with the shape constraints some more in
        # postprocessing, so don't freeze yet
        if not self.check_fn_manager.output_graph.export:
            output_graph.shape_env.freeze()

        # Record export info once per produced guard expression.
        for shape_guard in guards:
            self._set_guard_export_info(guard, [shape_guard])

        if config.enable_cpp_guard_manager:
            # Install all the symbolic guards in one lambda guard. These are run
            # at the very end of the RootGuardManager via epilogue guards.
            # TODO(anijain2305,williamwen42) - Consider moving this to C++.
            code_parts = guards
            self.add_python_lambda_leaf_guard_to_root(
                code_parts,
                get_verbose_code_parts(code_parts, guard),
                closure_vars={**SYMPY_INTERP, **CLOSURE_VARS},
            )
        else:
            for shape_guard in guards:
                self._produce_guard_code(guard, [shape_guard], shape_env=True)
1866
    def TENSOR_MATCH(self, guard: Guard, value=None):
        """Install the guards for a tensor input.

        Nothing is installed for FSDP modules. Tensors from specialized nn
        modules (unless their source is a ``NumpyTensorSource``) and tensors
        selected by ``match_on_id_for_tensor`` get a plain ``ID_MATCH``.

        Otherwise:
          * in export mode, a ``TYPE_MATCH`` plus Python comparisons on
            dtype, device, requires_grad and ndimension() are generated;
          * outside of export, the tensor is recorded in the
            ``tensor_check_*`` lists and, with the C++ guard manager, a
            tensor match guard is added using the concrete sizes/strides
            recorded for the guard's source;
          * if the tensor is not always static, a guard on
            ``_dynamo_dynamic_indices`` is added as well.

        Args:
            guard: the guard being built.
            value: optional tensor (or ``TensorWeakRef`` to it); when it is
                ``None`` the tensor is looked up via ``guard.name``.
        """
        # For FSDP modules, we can skip guards on nn module tensors because FSDP
        # eager assumes that the params are unchanged once the model is wrapped.
        if guard.is_fsdp_module():
            return

        # For tensors that are part of the Dynamo extracted Fx graph module, an
        # ID_MATCH suffices. Once we turn on inline_inbuilt_nn_modules, these
        # will be lifted as inputs and have a TENSOR_MATCH guard.
        # For numpy tensors, always use TENSOR_MATCH because __from_numpy leads
        # to a new tensor every time and therefore id differs.
        if (
            guard.is_specialized_nn_module()
            and not isinstance(guard.originating_source, NumpyTensorSource)
        ) or match_on_id_for_tensor(guard):
            self.ID_MATCH(guard)
        else:
            # Dereference a weak reference; the result may be None, in which
            # case the lookup below takes over.
            if isinstance(value, TensorWeakRef):
                value = value()

            value = value if value is not None else self.get(guard.name)
            assert isinstance(value, torch.Tensor)

            tensor_name = self.arg_ref(guard)
            # [Note - On Export Tensor Guards]
            #
            # In eager mode, tensor guards are evaluated through C++, in guards.cpp
            # see [Note - On Eager Tensor Guards] for more info.
            #
            # In export mode, we instead maintain parallel logic between C++ and python
            # here, with an exception of checking the dispatch key - with the idea that a dispatch key
            # is an entirely runtime notion that would make no sense to keep in an exported graph.
            #
            # Now, this idea is okay, but to paraphrase @ezyang, this mental model is sufficient for now, although
            # not entirely true.
            # For example, suppose one of the input tensors had the negative dispatch key.
            # You should end up with a graph that is specialized for tensors that have a negative dispatch key.
            # If you allow a Tensor that does NOT have this bit set, you will accidentally run it "as if" it were negated.
            # Now, negative key only shows up for complex numbers, and most likely, the exported to target doesn't
            # support this feature at all, but the point stands that :some: tensor state only shows up on dispatch key.
            # TODO(voz): Either populate a dispatch_key check into the guards, or error on users passing in an unsupported
            # subset of keys during export.
            #
            # The list of tensor fields and calls we care about can be found in `terms` below.
            # TODO(voz): We are missing storage offset in all our tensor guards?
            code: List[str] = []
            if self.check_fn_manager.output_graph.export:
                self.TYPE_MATCH(guard)
                terms = [
                    "dtype",
                    "device",
                    "requires_grad",
                    "ndimension()",
                ]

                for term in terms:
                    # Evaluate the term on the example tensor to get the value
                    # the generated check compares against.
                    real_value = self.get(tensor_name + "." + term)
                    if istype(real_value, (torch.device, torch.dtype)):
                        # copy pasted from EQUALS_MATCH
                        code.append(f"str({tensor_name}.{term}) == {str(real_value)!r}")
                    else:
                        code.append(f"{tensor_name}.{term} == {real_value}")
            else:
                # These three lists are appended to together, one entry per
                # tensor guarded this way.
                self.tensor_check_examples.append(value)
                self.tensor_check_names.append(tensor_name)
                self.tensor_check_guards.append(guard)

                if config.enable_cpp_guard_manager:
                    guard_manager = self.get_guard_manager(guard)
                    # Keep track of all the tensor guard managers to insert
                    # NoAliasing check at the end.
                    self.tensor_check_guard_managers.append(guard_manager)

                    output_graph = self.check_fn_manager.output_graph
                    metadata = output_graph.input_source_to_sizes_strides[
                        guard.originating_source
                    ]
                    size = convert_to_concrete_values(metadata["size"])
                    stride = convert_to_concrete_values(metadata["stride"])

                    verbose_code_parts = get_verbose_code_parts(
                        get_tensor_guard_code_part(value, tensor_name, size, stride),
                        guard,
                    )
                    guard_manager.add_tensor_match_guard(
                        value,
                        size,
                        stride,
                        tensor_name,
                        verbose_code_parts,
                    )

            # A frame is valid for reuse with dynamic dimensions if the new
            # (user-requested) dynamic dimensions are a subset of the old
            # (already compiled) dynamic dimensions.
            #
            # It's a little non-obvious why you'd want this: in particular,
            # if an already compiled frame matches all of the guards, why
            # not just use it, why force a recompile?
            #
            # We force it for two reasons:
            #
            #   - The user *required* us to compile with a new dynamic dimension,
            #     we should not ignore that and serve up the old, specialized
            #     frame.  Listen to the user!
            #
            #   - In fact, we are obligated to *raise an error* if we fail to
            #     make the requested dimension dynamic.  If we don't
            #     recompile, we can't tell if that dimension can actually be
            #     made dynamic.
            #
            # If the new dynamic dims are a subset of the old, we already know
            # we can make them dynamic (since we made them dynamic in old).
            # This is slightly unsound, because maybe your input size is
            # [s0, s0, s1] and so you can do it dynamic if you say dynamic
            # dims {0, 1, 2} but you can't if you only do {0, 2} (because now
            # the second s0 is specialized).  But we're not entirely sure if
            # this is a good idea anyway lol... (if you want to try removing
            # this logic, be my guest!  -- ezyang 2024)
            #
            assert guard.source is not None
            static, reason = tensor_always_has_static_shape(
                value, is_tensor=True, tensor_source=guard.originating_source
            )

            if not static:
                if hasattr(value, "_dynamo_dynamic_indices"):
                    dynamic_indices = value._dynamo_dynamic_indices
                    code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)"  # noqa: B950
                    code.append(code_part)
                    if config.enable_cpp_guard_manager:
                        self.get_guard_manager(guard).add_dynamic_indices_guard(
                            dynamic_indices, get_verbose_code_parts(code_part, guard)
                        )
                # In the case of us not having any dynamic dimension indices, we compiled the frame with no chance of
                # raising for this specific tensor - and any inputs with more dynamic user directives specified must be recompiled.
                else:
                    code_part = (
                        f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False"
                    )
                    code.append(code_part)
                    if config.enable_cpp_guard_manager:
                        self.get_guard_manager(guard).add_no_hasattr_guard(
                            "_dynamo_dynamic_indices",
                            get_verbose_code_parts(code_part, guard),
                        )
            # Export info is recorded whenever any Python expression was
            # generated; the expressions themselves are only emitted as guard
            # code when the C++ guard manager is off.
            if len(code) > 0:
                self._set_guard_export_info(guard, code)
                if not config.enable_cpp_guard_manager:
                    self._produce_guard_code(guard, code)
2017
2018    # A util that appends guarded code
2019    def _produce_guard_code(self, guard, code_list, shape_env=False):
2020        assert not config.enable_cpp_guard_manager
2021        if shape_env:
2022            self.shape_env_code.append(GuardCodeList(code_list, guard))
2023        else:
2024            self.code.append(GuardCodeList(code_list, guard))
2025
2026    # A util that in the case of export, adds data onto guards
2027    def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None):
2028        # WARNING: It is important that cur_frame/caller do NOT stay in
2029        # the current frame, because they will keep things live longer
2030        # than they should.  See TestMisc.test_release_module_memory
2031        cur_frame = currentframe()
2032        assert cur_frame is not None
2033        caller = cur_frame.f_back
2034        del cur_frame
2035        assert caller is not None
2036        func_name = getframeinfo(caller)[2]
2037        del caller
2038        # We use func_name for export, so might as well get a nice defensive check out of it
2039        assert func_name in dir(
2040            self.__class__
2041        ), f"_produce_guard_code must be called from inside GuardedCode. Called from {func_name}"
2042
2043        # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD)
2044        if provided_guarded_object is None:
2045            name_valid = guard.name is not None and guard.name != ""
2046
2047            guarded_object = self.get(guard.name) if name_valid else None
2048        else:
2049            guarded_object = provided_guarded_object
2050
2051        guarded_object_type = (
2052            weakref.ref(type(guarded_object)) if guarded_object is not None else None
2053        )
2054        obj_ref = None
2055        # Not necessary to have weakref for Enum type, but there is a bug that
2056        # makes hasattr(guarded_object.__class__, "__weakref__") return True.
2057        if hasattr(guarded_object.__class__, "__weakref__") and not isinstance(
2058            guarded_object, enum.Enum
2059        ):
2060            obj_ref = weakref.ref(guarded_object)
2061
2062        guard.set_export_info(
2063            func_name,
2064            guarded_object_type,
2065            code_list,
2066            obj_ref,
2067        )
2068
2069
2070# Common Sub-Expression Elimination for Python expressions.
2071#
2072# There are 2 steps to this pass:
2073#     1. Count the frequency of each sub-expression (i.e. inner
2074#        node in the AST tree)
2075#
2076#     2. Replace those that occur more than once by a fresh variable 'v'.
2077#        'v' will be defined in the 'preface' list (output argument to
2078#        'NodeTransformer')
2079#
2080# NB: the use of 'ast.unparse' while visiting the nodes makes this pass
2081# quadratic on the depth of the tree.
2082#
2083# NB: this pass creates a new variable for each AST node that is repeated
2084# more than 'USE_THRESHOLD'. e.g. if 'a.b.c.d' is used 10 times, 'a.b.c'
2085# and 'a.b' are also used 10 times. So, there will be a new variable for
2086# each of them.
class PyExprCSEPass:
    """Common sub-expression elimination over Python expression strings.

    The pass is used in two phases that share one ``Config``:

    1. ``count`` tallies how often each attribute access, call or subscript
       occurs across a list of expressions.
    2. ``replace`` rewrites one expression, substituting every sub-expression
       counted more than ``USE_THRESHOLD`` times with a generated variable
       and returning the assignments that define those variables.
    """

    # Maximum number of times a given expression can be used without being
    # replaced by a fresh variable.
    USE_THRESHOLD = 1

    # Ad-Hoc: AST nodes this pass focuses on.
    ALLOWED_NODE_TYPES = (ast.Attribute, ast.Call, ast.Subscript)

    @dataclasses.dataclass
    class Config:
        # Occurrence count, keyed by the unparsed text of a sub-expression.
        expr_count: Dict[str, int]
        # Generated variable name, keyed by the unparsed text it replaces.
        expr_to_name: Dict[str, str]

    class ExprCounter(ast.NodeVisitor):
        """Visitor that counts occurrences of the allowed node types."""

        def __init__(self, config: PyExprCSEPass.Config) -> None:
            self._config = config

        def visit(self, node: ast.AST) -> Any:
            if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES):
                text = _ast_unparse(node)
                self._config.expr_count[text] += 1
            super().visit(node)

    class Replacer(ast.NodeTransformer):
        """Transformer that swaps repeated sub-expressions for variables.

        The assignments defining the generated variables are collected, in
        definition order, in ``self.preface``.
        """

        def __init__(
            self,
            config: PyExprCSEPass.Config,
            gen_name: Callable[[], str],
        ) -> None:
            super().__init__()
            self._config = config
            self._gen_name = gen_name
            self.preface: List[str] = []

        def visit(self, node: ast.AST) -> Any:
            if not isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES):
                return super().visit(node)

            key = _ast_unparse(node)

            # Replacement only occurs if a given expression is used more
            # than once.
            if self._config.expr_count[key] <= PyExprCSEPass.USE_THRESHOLD:
                return super().visit(node)

            names = self._config.expr_to_name
            if key not in names:
                # The parent 'visit' runs first so that inner expressions are
                # CSE-ed before the expression containing them; its result
                # becomes the right-hand side of the new assignment.
                #
                # The lookup key stays the text of the original 'node', since
                # that is what the counter recorded.
                rewritten = super().visit(node)
                rhs = _ast_unparse(rewritten)
                fresh = self._gen_name()
                self.preface.append(f"{fresh} = {rhs}")
                names[key] = fresh

            return ast.Name(names[key], ast.Load())

    def __init__(self) -> None:
        self._counter = 0
        self._config = self.Config(
            expr_count=collections.defaultdict(int), expr_to_name={}
        )

    def _new_var(self, prefix: str = "_var") -> str:
        """Return the next unused variable name, ``<prefix><n>``."""
        index = self._counter
        self._counter = index + 1
        return f"{prefix}{index}"

    def count(self, exprs: List[str]) -> None:
        """Add the sub-expressions of every string in ``exprs`` to the tally.

        A ``SyntaxError`` from parsing is logged and re-raised.
        """
        visitor = self.ExprCounter(self._config)
        for source in exprs:
            try:
                visitor.visit(ast.parse(source))
            except SyntaxError as ex:
                log.exception(
                    "Failed to visit expr at line %s.\n%s", ex.lineno, source
                )
                raise

    def replace(self, expr: str) -> Tuple[List[str], str]:
        """Rewrite ``expr`` using the counts gathered so far.

        Returns the list of new variable assignments introduced by this call
        and the rewritten expression text.
        """
        rewriter = self.Replacer(self._config, self._new_var)
        tree = rewriter.visit(ast.parse(expr))
        return rewriter.preface, _ast_unparse(tree)
2170
2171
def must_add_nn_module_guards(guard):
    """Tell whether ``guard`` must be kept even when nn module guards are skipped.

    For config.guard_nn_modules=False, we can skip all the guards that
    originate from inside of nn module except for a few categories, which are
    the ones recognised here.
    """
    # Guard for defaults
    if isinstance(guard.originating_source, DefaultsSource):
        return True

    # Guard using dict tags if the config flag is set
    return (
        config.guard_nn_modules_using_dict_tags
        and guard.create_fn is GuardBuilder.NN_MODULE
    )
2184
2185
class DeletedGuardFn:
    """Empty marker class; it defines no attributes or behavior of its own."""
2188
2189
2190# NB: Naively, you'd expect this to only be a function that produces
2191# the callable that constitutes the guard.  However, there is some
2192# delicate handling for invalidating this check function when the
2193# locals/globals get invalidated, so there's some extra state
2194# we have to hold in this manager class.
class CheckFunctionManager:
    """Build the guard check function for an output graph and manage its invalidation.

    ``__init__`` runs every guard in ``output_graph.guards`` through a
    ``GuardBuilder`` and stores the resulting callable in ``self.check_fn``.
    The instance also owns the weakref bookkeeping (``id_ref`` /
    ``lookup_weakrefs``) and the ``invalidate`` hook that drops the cache entry
    attached to the check function.
    """

    def __init__(
        self,
        output_graph=None,
        guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
    ):
        """Create all guards for ``output_graph`` and compile them into ``self.check_fn``.

        Args:
            output_graph: object whose ``guards``, ``local_scope``,
                ``global_scope``, ``export`` and ``torch_function_mode_stack``
                are read here.
                NOTE(review): although the default is ``None``, the scopes and
                ``export`` are dereferenced unconditionally below, so a real
                output graph appears to be required -- confirm with callers.
            guard_fail_fn: optional callback, stored on the check function as
                ``guard_fail_fn``.

        Raises:
            AssertionError: when the C++ guard manager is enabled, the graph is
                not being exported, and the freshly built guards fail on
                ``output_graph.local_scope``.
        """
        guards = output_graph.guards if output_graph else None
        self._weakrefs: Dict[int, ReferenceType[object]] = {}
        self.guard_manager = None
        if config.enable_cpp_guard_manager:
            self.guard_manager = GuardManager()
        self.output_graph = output_graph
        # Rebound to a weakref to the builder further down; source_ref reads it
        # lazily through its closure.
        w_builder = None

        self.torch_function_mode_stack = (
            output_graph.torch_function_mode_stack if output_graph else None
        )

        def source_ref(source):
            guard_source = source.guard_source()
            if guard_source is GuardSource.CONSTANT:
                # No need to track constants
                return source.name()
            assert w_builder
            r_builder = w_builder()
            assert r_builder is not None
            return r_builder.arg_ref(source.name())

        builder = GuardBuilder(
            self.id_ref,
            source_ref,
            self.lookup_weakrefs,
            output_graph.local_scope,
            output_graph.global_scope,
            self.guard_manager,
            self,
        )

        # Break retain cycle. See test_release_scope_memory
        # NOTE(review): a weakref callback runs once the referent is no longer
        # reachable, so weak_b() presumably returns None here -- confirm
        # whether this cleanup can ever take effect.
        def cleanup_builder(weak_b):
            b = weak_b()
            if b:
                b.scope = None

        # Break retain cycle. See test_release_input_memory
        w_builder = weakref.ref(builder, cleanup_builder)

        guard_on_nn_modules = config.guard_nn_modules and justknobs_check(
            "pytorch/compiler:guard_nn_modules"
        )

        if not justknobs_check("pytorch/compiler:guard_nn_modules"):
            log.warning("guard_nn_modules is turned off using justknobs killswitch")

        for guard in sorted(guards or [], key=Guard.sort_key):
            if (
                not guard_on_nn_modules
                and guard.is_specialized_nn_module()
                # Default func args must be guarded on.
                # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
                and "__defaults__" not in guard.name
                and "__kwdefaults__" not in guard.name
                and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
            ):
                continue

            guard.create(builder)

        self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)

        # Keep track of weak references of objects with ID_MATCH guard. This
        # info is stored alongside optimized_code and check_fn and is used to
        # limit the number of cache entries with same ID_MATCH'd object.
        # TODO(anijain2305) - Currently this information is stored as an attr on
        # the check_fn itself to avoid changing CacheEntry datastructure in
        # eval_frame.c. In future, we should probably replace check_fn with a
        # queryable data structure such that this information is already present
        # in some form.
        self.check_fn.id_matched_objs = builder.id_matched_objs

        if config.enable_cpp_guard_manager:
            # TODO: don't do the string rep, do something more structured here
            torch._logging.trace_structured(
                "dynamo_cpp_guards_str", payload_fn=lambda: str(self.guard_manager)
            )
            guards_log.debug("%s", self.guard_manager)
            assert self.guard_manager  # to make mypy happy
            self.guard_manager.id_matched_objs = builder.id_matched_objs
            self.check_fn = self.guard_manager

            # Check that the guard returns True. False means that we will always
            # recompile.
            # TODO(anijain2305, ydwu4) - Skipping export because of following test
            # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs
            if not output_graph.export:
                if not self.guard_manager.check(output_graph.local_scope):
                    reasons = get_guard_fail_reason_helper(
                        self.guard_manager,  # type: ignore[arg-type]
                        output_graph.local_scope,
                        CompileContext.current_compile_id(),
                    )
                    raise AssertionError(f"Guard check failed: {reasons}")

        # NB - We have to be very careful of cleaning up here. Because of the
        # invalidate function, we can create a weakref finalizer that keeps
        # `self` alive for very long. Sometimes by mistake, we can run
        # invalidate for a type/object (check id_ref method) that Python can
        # leak by design, preventing us from calling the finalizer. In that
        # case, the `self` will be alive even though the cache entry will be
        # deleted (check invalidate method), which can cause a memory leak,
        # e.g., not setting output_graph = None can keep hold of nn_modules.
        self._weakrefs.clear()
        self.output_graph = None

    def compile_check_fn(self, builder, guards_out, guard_fail_fn):
        """Assemble the guard callable from the state collected in ``builder``.

        Args:
            builder: the ``GuardBuilder`` all guards were created on; its
                ``argnames``, ``code``, ``shape_env_code``, ``scope`` and
                tensor-check lists are read here.
            guards_out: not used by this method.
            guard_fail_fn: stored on the returned callable as ``guard_fail_fn``.

        Returns:
            ``self.guard_manager`` when ``config.enable_cpp_guard_manager`` is
            set; otherwise the Python function obtained by exec'ing the source
            produced by ``build_guard_function``. In both cases the debugging
            attributes assigned at the end of this method are attached to it.
        """
        # see parallel handling of ".0" / "___implicit0" in _eval_frame.c
        # Build a fresh list so that builder.argnames is not extended in place.
        largs = [*builder.argnames, "**___kwargs_ignored"]

        guards_log.debug("GUARDS:")

        code_parts = []
        verbose_code_parts = []
        structured_guard_fns: list[Callable[[], dict[str, Any]]] = []

        torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
            self.torch_function_mode_stack
        )

        if config.enable_cpp_guard_manager:
            from .variables.torch_function import IGNORED_MODES

            # Insert the global_state guard
            assert self.guard_manager  # to make mypy happy
            self.guard_manager.root.add_global_state_guard(["___check_global_state()"])

            self.guard_manager.root.add_torch_function_mode_stack_guard(
                self.torch_function_mode_stack,
                list(IGNORED_MODES),
                ["___check_torch_function_mode_stack()"],
            )
            # Clear references to torch_function modes held in the list
            self.torch_function_mode_stack = None
        else:
            # Don't report this guard, it's always the same, useless!
            global_guard = "___check_global_state()"
            code_parts.append(global_guard)
            verbose_code_parts.append(global_guard)

            tf_mode_stack_guard = "___check_torch_function_mode_stack()"
            code_parts.append(tf_mode_stack_guard)
            verbose_code_parts.append(tf_mode_stack_guard)

        def add_code_part(code_part, guard, log_only=False):
            """Log one guard expression and, unless ``log_only``, add it to the Python guard body."""
            verbose_code_part = get_verbose_code_part(code_part, guard)
            guards_log.debug("%s", verbose_code_part)

            # `guard` is None for the aliasing guards added further down, so
            # the lazily evaluated payload must not dereference it blindly.
            structured_guard_fns.append(
                lambda: {
                    "code": code_part,
                    "stack": structured.from_traceback(guard.stack.summary())
                    if guard is not None and guard.stack
                    else None,
                    "user_stack": structured.from_traceback(guard.user_stack)
                    if guard is not None and guard.user_stack
                    else None,
                }
            )

            if verbose_guards_log.isEnabledFor(logging.DEBUG):
                maybe_stack = ""
                maybe_user_stack = ""
                if guard is not None:
                    if guard.stack:
                        maybe_stack = f"\nStack:\n{''.join(guard.stack.format())}"
                    if guard.user_stack:
                        maybe_user_stack = (
                            f"\nUser stack:\n{''.join(guard.user_stack.format())}"
                        )
                verbose_guards_log.debug(
                    "Guard: %s%s%s",
                    code_part,
                    maybe_stack,
                    maybe_user_stack,
                )

            if not log_only:
                code_parts.append(code_part)
                verbose_code_parts.append(verbose_code_part)

        seen = set()
        for gcl in builder.code:
            for code in gcl.code_list:
                if code not in seen:
                    # If Cpp guard manager is enabled, we don't need to add to
                    # code_parts.
                    add_code_part(code, gcl.guard, config.enable_cpp_guard_manager)
                    seen.add(code)

        tensor_check_names = builder.tensor_check_names
        check_tensors_fn = None
        check_tensors_verbose_fn = None
        if tensor_check_names and not config.enable_cpp_guard_manager:
            tensor_check_guards = builder.tensor_check_guards
            assert (
                not self.output_graph.export
            ), "Illegal to set tensor_check_names in export."
            tensor_check_examples = builder.tensor_check_examples

            dynamic_dims_sizes = []
            dynamic_dims_strides = []
            for t, g in zip(tensor_check_examples, tensor_check_guards):
                metadata = self.output_graph.input_source_to_sizes_strides[
                    g.originating_source
                ]
                dynamic_dims_sizes.append(convert_to_concrete_values(metadata["size"]))
                dynamic_dims_strides.append(
                    convert_to_concrete_values(metadata["stride"])
                )

            tensor_guards = TensorGuards(
                *tensor_check_examples,
                dynamic_dims_sizes=dynamic_dims_sizes,
                dynamic_dims_strides=dynamic_dims_strides,
            )
            check_tensors_fn = tensor_guards.check
            check_tensors_verbose_fn = tensor_guards.check_verbose
            tensor_check_args = ", ".join(
                tensor_check_names + ["tensor_check_names=tensor_check_names"]
            )
            # Do this manually, to un-stagger the guards in log message
            code_parts.append(f"___check_tensors({tensor_check_args})")
            verbose_code_parts.append(f"___check_tensors({tensor_check_args})")

            for i, name in enumerate(tensor_check_names):
                # This is a copy of what guards.cpp checks against
                # Keep this in sync with TensorCheck constructor
                t = tensor_check_examples[i]
                sizes = dynamic_dims_sizes[i]
                strides = dynamic_dims_strides[i]
                code_part = get_tensor_guard_code_part(t, name, sizes, strides)
                add_code_part(code_part, tensor_check_guards[i], log_only=True)

        if len(tensor_check_names) > 1 and config.enable_cpp_guard_manager:
            # Install tensor aliasing guard. TENSOR_MATCH guards are already
            # installed for cpp guard manager.
            install_no_tensor_aliasing_guard(
                builder.tensor_check_guard_managers,
                tensor_check_names,
                ["check_no_aliasing(" + ", ".join(tensor_check_names) + ")"],
            )

        aotautograd_guards: List[GuardEnvExpr] = (
            self.output_graph.tracing_context.guards_context.aotautograd_guards
            if self.output_graph
            else []
        )

        # TODO(anijain2305) - There is a duplicate logic in Dynamo to find
        # aliased input tensors. So most probably we don't need this here.
        # Revisit.
        for guard in aotautograd_guards:
            if isinstance(guard, DuplicateInputs):
                source_a = guard.input_source_a
                source_b = guard.input_source_b
                code_part = f"{source_a.name()} is {source_b.name()}"
                if config.enable_cpp_guard_manager:
                    install_object_aliasing_guard(
                        builder.get_guard_manager_from_source(source_a),
                        builder.get_guard_manager_from_source(source_b),
                        [code_part],
                    )
                # No Guard object exists for this expression, hence None.
                add_code_part(code_part, None, config.enable_cpp_guard_manager)
            else:
                raise RuntimeError(f"Unknown GuardEnvExpr: {guard}")

        # TODO: the "guard" here is actually just the top level SHAPE_ENV
        # which is useless.  Get ShapeEnv to pass in more provenance.
        for gcl in builder.shape_env_code:
            for code in gcl.code_list:
                # Shape env guards are already added for CPP guard manager in
                # SHAPE_ENV implementation.
                add_code_part(code, gcl.guard, config.enable_cpp_guard_manager)

        # OK, all done generating guards
        if structured_guard_fns:
            torch._logging.trace_structured(
                "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
            )

        global_state = convert_frame.initial_global_state
        if global_state is None:
            # we should only hit this case in NopTests()
            global_state = convert_frame.GlobalStateGuard()
        closure_vars = {
            "___check_tensors": check_tensors_fn,
            "___check_tensors_verbose": check_tensors_verbose_fn,
            "___check_global_state": global_state.check,
            "___check_torch_function_mode_stack": torch_function_mode_stack_check_fn,
            "tensor_check_names": tensor_check_names,
            **SYMPY_INTERP,
            **CLOSURE_VARS,
        }

        globals_for_guard_fn = {"G": builder.scope["G"]}
        if config.enable_cpp_guard_manager:
            # Guard manager construction is complete
            assert self.guard_manager  # to make mypy happy
            # TODO (anijain2305) - When enable_cpp_guard_manager is ON by
            # default, change the guard_fn name to be guard_manager everywhere
            # to avoid confusion.
            guard_fn = self.guard_manager
            # Ensure we did not miss to insert a guard in cpp guard manager.
            assert len(code_parts) == 0
        else:
            unique_code_parts = list(unique(code_parts))
            make_guard_fn_args = ", ".join(closure_vars.keys())
            guard_body, pycode = build_guard_function(
                unique_code_parts, make_guard_fn_args
            )

            if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1":
                print("GUARDS\n", guard_body)

            out: Dict[str, Any] = {}

            # We don't put builder.scope as the globals in exec call because
            # guard_fn.__globals__ becomes equal to builder.scope. This causes
            # guard_fn to hold a reference to f_locals sitting in builder.scope["L"]
            try:
                exec(pycode, globals_for_guard_fn, out)
            except SyntaxError as ex:
                log.exception("Failed to exec guard at line %s.\n%s", ex.lineno, pycode)
                raise
            guard_fn = out["___make_guard_fn"](*closure_vars.values())

        guard_fn.closure_vars = closure_vars
        # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
        guard_fn.args = largs
        if config.enable_cpp_guard_manager:
            guard_fn.populate_code_parts_for_debugging()
        else:
            guard_fn.code_parts = code_parts
        guard_fn.verbose_code_parts = verbose_code_parts
        # Grab only G, but preserve "G" because guards access it as "G"
        guard_fn.global_scope = globals_for_guard_fn
        guard_fn.guard_fail_fn = guard_fail_fn
        # will be populated by a non-owning reference to CacheEntry/ExtraState
        # when the CacheEntry is constructed
        guard_fn.cache_entry = None
        guard_fn.extra_state = None
        guard_fn.no_tensor_aliasing_sources = tensor_check_names
        return guard_fn

    def invalidate(self):
        """Invalidate the cache entry attached to ``check_fn`` and retire ``check_fn``.

        Does nothing unless ``check_fn`` exists, has not been replaced by
        ``DeletedGuardFn`` yet, and carries both a cache entry and extra state.
        """
        # Some tests reveal that CheckFunctionManager has no attribute
        # check_fn, but this case should not be of any concern.
        # This case doesn't seem easy to repro.
        if (
            hasattr(self, "check_fn")
            and self.check_fn is not DeletedGuardFn
            and (cache_entry := self.check_fn.cache_entry) is not None
            and (extra_state := self.check_fn.extra_state) is not None
        ):
            assert isinstance(cache_entry, CacheEntry)
            assert isinstance(extra_state, ExtraState)
            extra_state.invalidate(cache_entry)
            self.check_fn.cache_entry = None
            self.check_fn.extra_state = None
            self.check_fn = DeletedGuardFn

    def id_ref(self, obj):
        """add a weakref, return the id"""
        try:
            if id(obj) not in self._weakrefs:
                # We will clear the _weakrefs dict at the end of __init__
                # function, which will delete the callbacks as well. Therefore,
                # we are using a finalizer which is kept alive.
                self._weakrefs[id(obj)] = weakref.ref(obj)
                weakref.finalize(obj, self.invalidate)
        except TypeError:
            pass  # cannot weakref bool object
        return id(obj)

    def lookup_weakrefs(self, obj):
        """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects"""
        # Returns None when no weakref was recorded for this object id.
        return self._weakrefs.get(id(obj))
2584
2585
def build_guard_function(code_parts, closure_args) -> Tuple[str, str]:
    """Generate Python source for a guard function from guard expressions.

    Args:
        code_parts: guard expressions; each one becomes an
            ``if not (...): return False`` check, in order.
        closure_args: comma-separated parameter list of the generated
            ``___make_guard_fn`` factory.

    Returns:
        A pair ``(guard_body, factory_source)``: the bare chain of checks, and
        the source of ``___make_guard_fn`` whose inner ``guard(L)`` function
        contains that chain followed by ``return True``.
    """
    from torch._inductor.utils import IndentedBuffer

    def passthrough(expr: str) -> Tuple[List[str], str]:
        # No preface lines, expression left untouched.
        return [], expr

    rewrite = passthrough
    if HAS_UNPARSE_FUNCTIONS:
        cse = PyExprCSEPass()
        cse.count(code_parts)
        # PyExprCSEPass.replace yields (preface lines, rewritten expression).
        rewrite = cse.replace

    # Inner body: one early-exit check per guard expression.
    checks = IndentedBuffer()
    for part in code_parts:
        preface_lines, rewritten = rewrite(part)
        checks.writelines(preface_lines)
        checks.writeline(f"if not ({rewritten}):")
        with checks.indent():
            checks.writeline("return False")

    # The guard function proper.
    guard_src = IndentedBuffer()
    guard_src.writeline("def guard(L):")
    with guard_src.indent():
        guard_src.splice(checks)
        guard_src.writeline("return True")

    # Factory that receives the closure variables and returns the guard.
    factory_src = IndentedBuffer()
    factory_src.writeline(f"def ___make_guard_fn({closure_args}):")
    with factory_src.indent():
        factory_src.splice(guard_src)
        factory_src.writeline("return guard")

    return checks.getvalue(), factory_src.getvalue()
2627
2628
2629def is_recompiles_enabled():
2630    return torch._logging._internal.log_state.is_artifact_enabled("recompiles")
2631
2632
2633def is_recompiles_verbose_enabled():
2634    return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose")
2635
2636
2637# this will only be used if cpp guards are disabled
def make_torch_function_mode_stack_guard(intial_stack):
    """Return a zero-argument check comparing the current torch function mode stack to ``intial_stack``.

    The returned function succeeds when the current stack has the same length
    and, position by position, the same mode types. Positions whose recorded
    type is in ``IGNORED_MODES`` are not compared.
    """
    expected_types = [type(mode) for mode in intial_stack]
    from .variables.torch_function import IGNORED_MODES

    def check_torch_function_mode_stack():
        current = get_torch_function_mode_stack()
        if len(current) != len(expected_types):
            return False

        # A mismatch only counts for positions that are not ignored.
        mismatch = any(
            expected not in IGNORED_MODES and expected != type(actual)
            for expected, actual in zip(expected_types, current)
        )
        return not mismatch

    return check_torch_function_mode_stack
2656
2657
def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
    """Report which tensor sources evaluate to the very same object.

    Every expression in ``guard_manager.no_tensor_aliasing_sources`` is
    evaluated against a copy of ``guard_manager.global_scope`` (globals) and
    ``scope`` (locals). Sources whose results share an ``id`` are grouped.

    Returns:
        A one-element list holding "Duplicate tensors found: " followed by the
        comma-separated groups that contain more than one source.
    """
    # Work on a copy so the guard manager's own globals stay untouched.
    eval_globals = dict(guard_manager.global_scope)
    sources_by_id = collections.defaultdict(list)
    for src in guard_manager.no_tensor_aliasing_sources:  # type: ignore[attr-defined]
        eval_globals["__compile_source__"] = src
        sources_by_id[id(eval(src, eval_globals, scope))].append(src)

    aliased_groups = [
        str(group) for group in sources_by_id.values() if len(group) > 1
    ]
    return ["Duplicate tensors found: " + ", ".join(aliased_groups)]
2673
2674
def get_guard_fail_reason_helper(
    guard_fn: GuardFn,
    f_locals: Dict[str, object],
    compile_id: CompileId,
) -> str:
    """
    Explain why `guard_fn` rejected `f_locals`.

    Returns "<compile_id>: " followed by the collected reasons joined by "; ".
    Unless `recompiles_verbose` logging is enabled, only the first failed
    check is reported.
    """
    eval_locals = {"L": f_locals, "G": guard_fn.global_scope["G"]}
    eval_locals.update(guard_fn.closure_vars)
    reasons: List[str] = []
    pending_parts: List[str] = []
    aliasing_check_failed = False

    if not config.enable_cpp_guard_manager:
        pending_parts = guard_fn.verbose_code_parts
        # Swap in the verbose tensor check. The C++ guard manager does not
        # need this because its verbose check already ran in C++.
        eval_locals["___check_tensors"] = eval_locals["___check_tensors_verbose"]
    else:
        debug_info = guard_fn.check_verbose(f_locals)  # type: ignore[attr-defined]
        # For test_export_with_map_cond, the check_verbose fail even without the
        # C++ guard manager. We need to fix the issue to remove the comment.
        # assert not debug_info.result
        if not debug_info.result:
            # Either the actual reason (e.g. for TENSOR_MATCH) or the list of
            # verbose code parts handed to the leaf guard at construction
            # time. In the latter case the list is walked below to find the
            # part that failed, which matters for symbolic shape guards that
            # are installed as one lambda guard covering many code parts.
            pending_parts = debug_info.verbose_code_parts
            if len(pending_parts) == 1:
                if "Duplicate tensor found" in pending_parts[0]:
                    aliasing_check_failed = True
                else:
                    reasons = pending_parts
                    pending_parts = []

    if aliasing_check_failed:
        reasons = recompilation_reason_for_no_tensor_aliasing_guard(
            guard_fn, eval_locals
        )
        return f"{compile_id}: " + "; ".join(reasons)

    for part in pending_parts:
        eval_globals = dict(guard_fn.global_scope)
        eval_globals["__compile_source__"] = part
        with report_compile_source_on_error():
            try:
                outcome = eval(part, eval_globals, eval_locals)
            except Exception:
                # In verbose mode a part that cannot be evaluated is skipped.
                if not is_recompiles_verbose_enabled():
                    raise
                continue
        # Only ___check_tensors knows how to return a fancy fail reason;
        # for everything else we just report the code that failed
        if isinstance(outcome, bool) and not outcome:
            outcome = part
        if isinstance(outcome, str):
            reasons.append(outcome)
            if not is_recompiles_verbose_enabled():
                break

    return f"{compile_id}: " + "; ".join(reasons)
2745
2746
def get_guard_fail_reason(
    guard_fn: GuardFn,
    code: types.CodeType,
    f_locals: Dict[str, object],
    compile_id: CompileId,
) -> str:
    """Compute the guard failure reason, record it, and notify ``guard_fn.guard_fail_fn``.

    The reason is appended to ``guard_failures`` under the original code
    object looked up via ``orig_code_map``. Exceptions raised while invoking
    the callback are logged and swallowed.

    Returns:
        The reason string produced by ``get_guard_fail_reason_helper``.
    """
    reason_str = get_guard_fail_reason_helper(guard_fn, f_locals, compile_id)
    orig_code = orig_code_map[code]
    guard_failures[orig_code].append(reason_str)

    try:
        if guard_fn.guard_fail_fn is not None:
            failure = GuardFail(reason_str or "unknown reason", orig_code)
            guard_fn.guard_fail_fn(failure)
    except Exception:
        log.exception(
            "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval",
        )

    return reason_str
2767
2768
def get_and_maybe_log_recompilation_reason(
    cache_entry, frame: types.FrameType
) -> List[str]:
    """
    Collect the guard failure reason of every entry in the cache entry chain.

    Walks the chain starting at `cache_entry` through `.next`.
    Logs the recompilation message if `recompiles` or `recompiles_verbose`
    logging is enabled, and raises a RecompileError if
    `config.error_on_recompile` is enabled.
    """
    reasons = []
    entry = cache_entry
    while entry is not None:
        failure = get_guard_fail_reason(
            entry.check_fn,
            entry.code,
            frame.f_locals,
            entry.compile_id,
        )
        if failure:
            reasons.append(failure)
        entry = entry.next

    frame_code = frame.f_code

    verbose = is_recompiles_verbose_enabled()
    # at least one of "recompiles" or "recompiles_verbose" is enabled
    should_log = is_recompiles_enabled() or verbose

    if should_log or config.error_on_recompile:
        if verbose:
            # One headed section per cache entry.
            sections = [
                f"guard {idx} failures:\n" + textwrap.indent(text, "- ")
                for idx, text in enumerate(reasons)
            ]
            failures = "\n\n".join(sections)
        else:
            failures = textwrap.indent("\n".join(reasons), "- ")

        details = f"triggered by the following guard failure(s):\n{failures}"
        headline = f"Recompiling function {frame_code.co_name} in {frame_code.co_filename}:{frame_code.co_firstlineno}\n"
        message = headline + textwrap.indent(details, "    ")

        if should_log:
            target_log = recompiles_verbose_log if verbose else recompiles_log
            target_log.debug(message)
        if config.error_on_recompile:
            raise exc.RecompileError(message)

    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "recompile_reasons",
            "encoding": "json",
        },
        payload_fn=lambda: reasons,
    )

    return reasons
2827
2828
def guard_error_hook(
    guard_fn: GuardFn,
    code: types.CodeType,
    f_locals: Dict[str, object],
    index: int,
    last: bool,
):
    """Print diagnostics for a guard function and flag the guards that cannot be evaluated.

    Prints the code location, the guard arguments and code parts, then
    re-evaluates each code part on its own and reports the ones that raise.
    ``index`` and ``last`` are accepted but not used here.
    """
    location = f"{code.co_name} {code.co_filename}:{code.co_firstlineno}"
    print(f"ERROR RUNNING GUARDS {location}")
    arg_list = ", ".join(guard_fn.args)
    print("lambda " + arg_list + ":")
    print(" ", " and\n  ".join(guard_fn.code_parts))

    if config.enable_cpp_guard_manager:
        print(guard_fn)

    eval_locals = {"L": f_locals, **guard_fn.closure_vars}
    for part in guard_fn.code_parts:
        try:
            eval(part, guard_fn.global_scope, eval_locals)
        except:  # noqa: B001,E722
            # Deliberately broad: any failure marks the guard as malformed.
            print(f"Malformed guard:\n{part}")
2851
2852
# Register guard_error_hook through set_guard_error_hook (imported from
# .eval_frame) at module import time.
# NOTE(review): presumably the hook runs when evaluating a guard raises --
# confirm against eval_frame.
set_guard_error_hook(guard_error_hook)
2854
2855
def unique(seq):
    """Lazily yield the items of ``seq`` in order, skipping ones already yielded.

    Items must be hashable.
    """
    emitted = set()
    for item in seq:
        if item in emitted:
            continue
        yield item
        emitted.add(item)
2862
2863
def make_dupe_guard(obj_source, dupe_source):
    """Return a DUPLICATE_INPUT guard factory for two sources of one object, or None.

    Returns:
        ``functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)``
        when ``dupe_source`` is set, differs from ``obj_source`` and both
        sources agree on being local; otherwise ``None``.

    Raises:
        exc.UnsafeScriptObjectError: if either source comes from a flattened
            script object.
    """
    # Note - we may end up in a situation where we invoke something like
    # def fn(x, y)
    # with fn(x, x)
    # Prior to the addition of tracking to all relevant objects, we would handle this just fine by
    # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. However,
    # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here -
    # In the fn(x, x) example call above look like a graph with a single input.
    # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard.

    # Note - we may not have a source, that is fine, it just means we had an object that is safe to have
    # leave unsourced - like a local list created and discharged entirely within a local scope.
    if not (dupe_source and dupe_source != obj_source):
        return None

    dupe_is_local = is_from_local_source(dupe_source)
    obj_is_local = is_from_local_source(obj_source)

    touches_script_object = any(
        is_from_flatten_script_object_source(src)
        for src in (dupe_source, obj_source)
    )
    if touches_script_object:
        raise exc.UnsafeScriptObjectError(
            f"{obj_source.name()} is alising {dupe_source.name()}. This is not supported."
            f" Please do a clone for corresponding input."
        )

    # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
    # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
    # so maybe we should do this refactor before we land this...
    # TODO(voz): Combine local and global guard builders.
    if dupe_is_local != obj_is_local:
        return None

    # Note - this is a little aggressive - these being duplicate input does not always matter.
    # However, this should always be a sound guard to add here.
    return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
2896
2897
def install_guard(*guards, skip=0):
    """
    Register dynamo guards with the current tracing context.

    Args:
        guards: guard(s) to add; each must be a Guard instance
        skip: number of stack frames to ignore for debug stack trace
    """
    from torch._guards import TracingContext

    # Debug stacks are only collected when one of the guard loggers is at DEBUG.
    collect_debug_stack = any(
        logger.isEnabledFor(logging.DEBUG)
        for logger in (guards_log, verbose_guards_log)
    )
    register = TracingContext.get().guards_context.dynamo_guards.add
    # One extra frame is skipped on top of what the caller asked for.
    frames_to_skip = skip + 1
    for guard in guards:
        assert isinstance(guard, Guard)
        register(guard, collect_debug_stack=collect_debug_stack, skip=frames_to_skip)
2915