1# mypy: allow-untyped-defs 2import logging 3import types 4import weakref 5from dataclasses import dataclass 6from typing import Tuple 7 8from torch._guards import CompileId 9 10from . import config 11 12 13log = logging.getLogger(__name__) 14""" 15[Note on cache size limit] 16 17Background - TorchDynamo cache is a linked list. Each cache entry is a 18(check_fn, out_code, next pointer). These are stored on the f_code's co_extra 19scratch space. When a frame is invoked, we walk this linked list and run 20check_fn in each cache_entry to decide if the frame needs recompilation. If none 21of the check_fn's returns True, we recompile and add a new entry. To ensure we 22don't end up recompiling infinitely, we put limits on the cache size. 23 24There are two limits 251) cache_size_limit 262) accumulated_cache_size_limit 27 28 29Earlier we used to have only limit - maximum number of entries in 1 cache line 30(which is now represented by (2) above). So, why do we need two limits? Lets try 31to understand that. 32 33In general, we want our cache limit value to be a small number (e.g. 8 or even 34lower). This ensures that for frames that cause too many recompilation fall to 35eager quickly. However, there is another problem that prevents us from lowering 36the value of cache_size_limit. This is due to ID_MATCH'd guards. Today, we put 37ID_MATCH guards on nn module if there is a graph break. This means we will have 38many recompilations for the same code object because the ID_MATCH guard fails 39for different instances of the nn module. This is a common pattern in how models 40are authored. Therefore, this requires us to keep the cache_size_limit high. 41 42We resolve this by introducing these two limits. The first limit (1) limits the 43number of cache entries that have an ID_MATCH'd guard for an nn module instance. 44And, (2)nd limit becomes a safeguard mechanism to have a maximum compilations 45for a code object. One important question is - what is the limit for the code 46object that does not have any ID_MATCH guard? For such code objects, we choose 47(1) as the cache size limit. 48 49Lets take an example to understand how these limits help. Suppose, we have 16 50instances of a nn module and we ID_MATCH on the self object. Further, suppose 51the inputs to these functions have varying batch size, leading to one 52recompilation. In total, there will be 32 recompilations, and therefore 32 cache 53entries on the forward code object. In the older case when we had only 1 limit, 54our cache size limit must be >= 32 to capture all these recompilations. Now, 55suppose there is a separate function in the same program which is very dynamic 56and unsuitable for compilation. Such a function will need to undergo 32 57compilations to burst the cache and fallback to eager. These 32 recompilations 58are too many and we want to fallback for these compilation-unfriendly functions 59sooner. 60 61In the new scenario, we can have (1) cache_size_limit = 2, (2) 62accumulated_cache_size_limit = 32. This means that each ID_MATCH'd object can 63have maximum of two cache entries, and the maximum number of cache entries 64(irrespective of ID_MATCH obj) is 32. This covers the case of forward code 65object which has 32 recompilations. For the other function, the one unsuitable 66for recompilation, our limit is 2. So, we will burst the cache in just 2 67recompilations. In this manner, these 2 limits help us resolve the tension 68mentioned earlier. 69""" 70 71 72@dataclass 73class CacheSizeRelevantForFrame: 74 """ 75 We track the number of cache entries that have same id_match objects as the 76 given frame. 77 78 TODO(janimesh) - Consider adding a map from tuple_of_match_ids to count - 79 https://github.com/pytorch/pytorch/pull/107496#discussion_r1304564682 - this 80 could be useful for debugging as well. 81 """ 82 83 # Total number of CacheEntry objects in the Dynamo linked list 84 num_cache_entries: int = 0 85 86 # Number of CacheEntry objects having same ID_MATCH'd objects as given frame. 87 num_cache_entries_with_same_id_matched_objs: int = 0 88 89 def will_compilation_exceed(self, limit: int) -> bool: 90 # Checks if a compilation will exceed the given limit (thats why >=). 91 return ( 92 self.will_compilation_exceed_accumulated_limit() 93 or self.will_compilation_exceed_specific_limit(limit) 94 ) 95 96 def will_compilation_exceed_accumulated_limit(self) -> bool: 97 return self.num_cache_entries >= config.accumulated_cache_size_limit 98 99 def will_compilation_exceed_specific_limit(self, limit: int) -> bool: 100 return self.num_cache_entries_with_same_id_matched_objs >= limit 101 102 103def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str): 104 obj = frame.f_locals.get(local_name, None) 105 weak_id = None 106 try: 107 weak_id = weakref.ref(obj) 108 except TypeError: 109 pass # cannot weakref bool object 110 return weak_id 111 112 113def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool: 114 """ 115 Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones 116 in frame.f_locals. 117 """ 118 if not cache_entry: 119 return False 120 121 for ( 122 local_name, 123 weakref_from_cache_entry, 124 ) in cache_entry.check_fn.id_matched_objs.items(): 125 if weakref_from_cache_entry() is not None: 126 weakref_from_frame = _get_weakref_from_f_locals(frame, local_name) 127 if weakref_from_frame != weakref_from_cache_entry: 128 return False 129 130 # Also covers the case where no ID_MATCH objects are saved in frame.f_locals 131 return True 132 133 134def compute_cache_size( 135 frame: types.FrameType, cache_entry 136) -> CacheSizeRelevantForFrame: 137 # Walk the linked list to calculate the cache size 138 num_cache_entries = 0 139 num_cache_entries_with_same_id_matched_objs = 0 140 141 while cache_entry: 142 num_cache_entries += 1 143 # Track the number of cache entries having same ID_MATCH'd objects as 144 # that of frame.f_locals. This will be used later to compare against the 145 # cache_size_limit. 146 if _has_same_id_matched_objs(frame, cache_entry): 147 num_cache_entries_with_same_id_matched_objs += 1 148 cache_entry = cache_entry.next 149 150 return CacheSizeRelevantForFrame( 151 num_cache_entries, num_cache_entries_with_same_id_matched_objs 152 ) 153 154 155def is_recompilation(cache_size: CacheSizeRelevantForFrame) -> bool: 156 """ 157 If the frame (earlier parsed by compute_cache_size) has more than 1 cache 158 entry with same ID_MATCH'd objects, then its a recompilation. 159 """ 160 # Note that you can have multiple entries in the cache but still not a 161 # recompile, e.g., you can have 64 nn module instances, each one having an 162 # ID_MATCH guard, and each one having just 1 cache entry in the cache. In 163 # this case, we can have 64 entries in the cache, but no recompilation 164 # because there is only one entry for each id_matched_obj. 165 return cache_size.will_compilation_exceed(1) 166 167 168def exceeds_cache_size_limit( 169 cache_size: CacheSizeRelevantForFrame, compile_id: CompileId 170) -> Tuple[bool, str]: 171 """ 172 Checks if we are exceeding the cache size limit. 173 """ 174 if cache_size.will_compilation_exceed_accumulated_limit(): 175 return True, "accumulated_cache_size_limit" 176 if cache_size.will_compilation_exceed_specific_limit(config.cache_size_limit): 177 return True, "cache_size_limit" 178 # NOTE this check is needed in the case that the frame's cache doesn't grow 179 # and we keep recompiling. This can happen if the guard check_fn becomes invalidated, 180 # e.g. due to guarded objects being freed. This technically makes the 181 # will_compilation_exceed_accumulated_limit check unnecessary, but we will keep the 182 # check in case we have a better fix in the future. 183 if compile_id.frame_compile_id >= config.accumulated_cache_size_limit: 184 return True, "accumulated_cache_size_limit" 185 return False, "" 186