xref: /aosp_15_r20/external/pytorch/torch/_dynamo/cache_size.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import logging
3import types
4import weakref
5from dataclasses import dataclass
6from typing import Tuple
7
8from torch._guards import CompileId
9
10from . import config
11
12
13log = logging.getLogger(__name__)
14"""
15[Note on cache size limit]
16
17Background - TorchDynamo cache is a linked list. Each cache entry is a
18(check_fn, out_code, next pointer). These are stored on the f_code's co_extra
19scratch space. When a frame is invoked, we walk this linked list and run
20check_fn in each cache_entry to decide if the frame needs recompilation. If none
21of the check_fn's returns True, we recompile and add a new entry. To ensure we
22don't end up recompiling infinitely, we put limits on the cache size.
23
24There are two limits
251) cache_size_limit
262) accumulated_cache_size_limit
27
28
Earlier we used to have only one limit - the maximum number of entries in one
cache line (which is now represented by (2) above). So, why do we need two
limits? Let's try to understand that.
32
In general, we want our cache limit value to be a small number (e.g. 8 or even
lower). This ensures that frames that cause too many recompilations fall back to
eager quickly. However, there is another problem that prevents us from lowering
36the value of cache_size_limit. This is due to ID_MATCH'd guards. Today, we put
37ID_MATCH guards on nn module if there is a graph break. This means we will have
38many recompilations for the same code object because the ID_MATCH guard fails
39for different instances of the nn module. This is a common pattern in how models
40are authored. Therefore, this requires us to keep the cache_size_limit high.
41
42We resolve this by introducing these two limits. The first limit (1) limits the
43number of cache entries that have an ID_MATCH'd guard for an nn module instance.
And, the (2)nd limit becomes a safeguard mechanism to have a maximum number of compilations
45for a code object. One important question is - what is the limit for the code
46object that does not have any ID_MATCH guard? For such code objects, we choose
47(1) as the cache size limit.
48
Let's take an example to understand how these limits help. Suppose we have 16
instances of an nn module and we ID_MATCH on the self object. Further, suppose
51the inputs to these functions have varying batch size, leading to one
52recompilation. In total, there will be 32 recompilations, and therefore 32 cache
53entries on the forward code object. In the older case when we had only 1 limit,
54our cache size limit must be >= 32 to capture all these recompilations. Now,
55suppose there is a separate function in the same program which is very dynamic
56and unsuitable for compilation. Such a function will need to undergo 32
57compilations to burst the cache and fallback to eager. These 32 recompilations
58are too many and we want to fallback for these compilation-unfriendly functions
59sooner.
60
61In the new scenario, we can have (1) cache_size_limit = 2, (2)
62accumulated_cache_size_limit = 32. This means that each ID_MATCH'd object can
63have maximum of two cache entries, and the maximum number of cache entries
64(irrespective of ID_MATCH obj) is 32. This covers the case of forward code
65object which has 32 recompilations. For the other function, the one unsuitable
66for recompilation, our limit is 2. So, we will burst the cache in just 2
67recompilations. In this manner, these 2 limits help us resolve the tension
68mentioned earlier.
69"""
70
71
@dataclass
class CacheSizeRelevantForFrame:
    """
    Per-frame view of Dynamo's cache occupancy: the total number of cache
    entries on the code object, plus how many of those entries guard on the
    same ID_MATCH'd objects as the frame currently being compiled.

    TODO(janimesh) - Consider adding a map from tuple_of_match_ids to count -
    https://github.com/pytorch/pytorch/pull/107496#discussion_r1304564682 - this
    could be useful for debugging as well.
    """

    # Length of the entire CacheEntry linked list for this code object.
    num_cache_entries: int = 0

    # Subset of the above whose ID_MATCH'd objects coincide with the frame's.
    num_cache_entries_with_same_id_matched_objs: int = 0

    def will_compilation_exceed(self, limit: int) -> bool:
        """Return True if one more compilation would cross either limit.

        We are asking about the *next* compilation, which is why the
        underlying comparisons use >= rather than >.
        """
        if self.will_compilation_exceed_accumulated_limit():
            return True
        return self.will_compilation_exceed_specific_limit(limit)

    def will_compilation_exceed_accumulated_limit(self) -> bool:
        """True once the total entry count reaches accumulated_cache_size_limit."""
        return self.num_cache_entries >= config.accumulated_cache_size_limit

    def will_compilation_exceed_specific_limit(self, limit: int) -> bool:
        """True once entries sharing the frame's ID_MATCH'd objects reach `limit`."""
        return self.num_cache_entries_with_same_id_matched_objs >= limit
101
102
def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str):
    """Return a weakref to frame.f_locals[local_name], or None.

    None is returned both when the local is absent and when the object does
    not support weak references (e.g. bool/int).
    """
    target = frame.f_locals.get(local_name, None)
    try:
        return weakref.ref(target)
    except TypeError:
        # Some objects (including None itself) cannot be weakly referenced.
        return None
111
112
def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool:
    """
    Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones
    in frame.f_locals.
    """
    if not cache_entry:
        return False

    for local_name, cached_ref in cache_entry.check_fn.id_matched_objs.items():
        if cached_ref() is None:
            # The guarded object has already been collected; nothing to compare.
            continue
        frame_obj = frame.f_locals.get(local_name, None)
        try:
            frame_ref = weakref.ref(frame_obj)
        except TypeError:
            frame_ref = None  # cannot weakref bool object
        # weakref.ref objects with live referents compare via the referents.
        if frame_ref != cached_ref:
            return False

    # Vacuously True when there are no ID_MATCH'd objects on the entry.
    return True
132
133
def compute_cache_size(
    frame: types.FrameType, cache_entry
) -> CacheSizeRelevantForFrame:
    """
    Walk the CacheEntry linked list starting at cache_entry and count (a) all
    entries and (b) the entries whose ID_MATCH'd objects coincide with the
    ones in frame.f_locals, packaging both counts for the limit checks.
    """
    total = 0
    matching = 0

    entry = cache_entry
    while entry:
        total += 1
        # Only entries guarding the same ID_MATCH'd objects as this frame
        # count toward the per-object cache_size_limit compared later.
        if _has_same_id_matched_objs(frame, entry):
            matching += 1
        entry = entry.next

    return CacheSizeRelevantForFrame(total, matching)
153
154
def is_recompilation(cache_size: CacheSizeRelevantForFrame) -> bool:
    """
    Given the stats produced by compute_cache_size for a frame, report whether
    compiling it again counts as a recompilation.
    """
    # A frame recompiles only when some ID_MATCH'd object already owns a cache
    # entry. Many entries can coexist without any recompilation: e.g. 64 nn
    # module instances, each with an ID_MATCH guard and exactly one cache
    # entry, give 64 entries total yet zero recompiles, because each
    # id_matched_obj appears only once.
    already_has_entry_for_same_objs = cache_size.will_compilation_exceed(1)
    return already_has_entry_for_same_objs
166
167
def exceeds_cache_size_limit(
    cache_size: CacheSizeRelevantForFrame, compile_id: CompileId
) -> Tuple[bool, str]:
    """
    Decide whether compiling this frame again would go past a cache size
    limit; returns (exceeded, name_of_limit_hit), with "" when within limits.
    """
    # NOTE the frame_compile_id check below is needed in the case that the
    # frame's cache doesn't grow and we keep recompiling. This can happen if
    # the guard check_fn becomes invalidated, e.g. due to guarded objects
    # being freed. This technically makes the accumulated-limit check
    # unnecessary, but we will keep the check in case we have a better fix in
    # the future. Order matters; predicates are evaluated lazily in sequence.
    checks = (
        (
            cache_size.will_compilation_exceed_accumulated_limit,
            "accumulated_cache_size_limit",
        ),
        (
            lambda: cache_size.will_compilation_exceed_specific_limit(
                config.cache_size_limit
            ),
            "cache_size_limit",
        ),
        (
            lambda: compile_id.frame_compile_id >= config.accumulated_cache_size_limit,
            "accumulated_cache_size_limit",
        ),
    )
    for check, limit_name in checks:
        if check():
            return True, limit_name
    return False, ""
186