# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python memory leak detection utility.

Please don't use this class directly. Instead, use `MemoryChecker` wrapper.
"""

import collections
import copy
import gc

from tensorflow.python.framework import _python_memory_checker_helper
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import trace


def _get_typename(obj):
  """Returns a human readable pretty type name string for `obj`.

  The name is qualified with the type's module when one is available, e.g.
  `'builtins.list'`; otherwise the bare type name is returned.
  """
  objtype = type(obj)
  name = objtype.__name__
  module = getattr(objtype, '__module__', None)
  if module:
    return '{}.{}'.format(module, name)
  else:
    return name


def _create_python_object_snapshot():
  """Returns a mapping from type name to the set of live object ids.

  Runs a full garbage collection first so that unreachable objects do not
  pollute the snapshot.
  """
  gc.collect()
  all_objects = gc.get_objects()
  result = collections.defaultdict(set)
  for obj in all_objects:
    result[_get_typename(obj)].add(id(obj))
  return result


def _snapshot_diff(old_snapshot, new_snapshot, exclude_ids):
  """Returns per-type object count growth between two snapshots.

  Args:
    old_snapshot: Mapping of type name -> set of object ids (earlier state).
    new_snapshot: Mapping of type name -> set of object ids (later state).
    exclude_ids: Set of object ids to ignore in both snapshots (the memory
      checker's own bookkeeping objects).

  Returns:
    A `collections.Counter` mapping type name to its positive count delta.
    Zero and negative deltas are dropped.
  """
  result = collections.Counter()
  for new_name, new_ids in new_snapshot.items():
    # Use .get() rather than indexing: `old_snapshot` is a defaultdict owned
    # by the caller and stored in `self._snapshots`, so a plain lookup of a
    # missing type would insert a new empty-set entry into it, mutating the
    # snapshot and creating fresh objects between measurements.
    old_ids = old_snapshot.get(new_name, set())
    result[new_name] = len(new_ids - exclude_ids) - len(old_ids - exclude_ids)

  # This removes zero or negative value entries.
  result += collections.Counter()
  return result


class _PythonMemoryChecker(object):
  """Python memory leak detection class."""

  def __init__(self):
    # One snapshot (type name -> set of object ids) per record_snapshot call.
    self._snapshots = []
    # Cache the function used by mark_stack_trace_and_call to avoid
    # contaminating the leak measurement.
    def _record_snapshot():
      self._snapshots.append(_create_python_object_snapshot())

    self._record_snapshot = _record_snapshot

  # We do not enable trace_wrapper on this function to avoid contaminating
  # the snapshot.
  def record_snapshot(self):
    """Appends a snapshot of all live Python objects to `self._snapshots`."""
    # Function called using `mark_stack_trace_and_call` will have
    # "_python_memory_checker_helper" string in the C++ stack trace. This will
    # be used to filter out C++ memory allocations caused by this function,
    # because we are not interested in detecting memory growth caused by memory
    # checker itself.
    _python_memory_checker_helper.mark_stack_trace_and_call(
        self._record_snapshot)

  @trace.trace_wrapper
  def report(self):
    """Prints a report of the leak analysis (not yet implemented)."""
    # TODO(kkb): Implement.
    pass

  @trace.trace_wrapper
  def assert_no_leak_if_all_possibly_except_one(self):
    """Raises an exception if a leak is detected.

    This algorithm classifies a series of allocations as a leak if it's the same
    type at every snapshot, but possibly except one snapshot.

    Raises:
      AssertionError: If any object type grew in all (or all but one) of the
        consecutive snapshot diffs.
    """

    # Diff each consecutive pair of snapshots.
    snapshot_diffs = []
    for i in range(0, len(self._snapshots) - 1):
      snapshot_diffs.append(self._snapshot_diff(i, i + 1))

    # Count, for each type, in how many of the diffs it grew.
    allocation_counter = collections.Counter()
    for diff in snapshot_diffs:
      for name, count in diff.items():
        if count > 0:
          allocation_counter[name] += 1

    # A type leaking in every diff except at most one is reported.
    leaking_object_names = {
        name for name, count in allocation_counter.items()
        if count >= len(snapshot_diffs) - 1
    }

    if leaking_object_names:
      object_list_to_print = '\n'.join(
          [' - ' + name for name in leaking_object_names])
      raise AssertionError(
          'These Python objects were allocated in every snapshot possibly '
          f'except one.\n\n{object_list_to_print}')

  @trace.trace_wrapper
  def assert_no_new_objects(self, threshold=None):
    """Assert no new Python objects.

    Compares the first and last snapshots; raises if any type's growth
    exceeds its threshold, and logs a warning if growth is below threshold.

    Args:
      threshold: Optional mapping of type name -> allowed object count growth.
        Defaults to no allowance for any type.

    Raises:
      AssertionError: If any type grew by more than its threshold.
    """

    if not threshold:
      threshold = {}

    count_diff = self._snapshot_diff(0, -1)
    # Keep the raw diff for the error/warning message before subtracting.
    original_count_diff = copy.deepcopy(count_diff)
    count_diff.subtract(collections.Counter(threshold))

    if max(count_diff.values() or [0]) > 0:
      raise AssertionError('New Python objects created exceeded the threshold.'
                           '\nPython object threshold:\n'
                           f'{threshold}\n\nNew Python objects:\n'
                           f'{original_count_diff.most_common()}')
    elif min(count_diff.values(), default=0) < 0:
      logging.warning('New Python objects created were less than the threshold.'
                      '\nPython object threshold:\n'
                      f'{threshold}\n\nNew Python objects:\n'
                      f'{original_count_diff.most_common()}')

  @trace.trace_wrapper
  def _snapshot_diff(self, old_index, new_index):
    """Returns the per-type growth Counter between two stored snapshots."""
    return _snapshot_diff(self._snapshots[old_index],
                          self._snapshots[new_index],
                          self._get_internal_object_ids())

  @trace.trace_wrapper
  def _get_internal_object_ids(self):
    """Returns ids of the checker's own objects, to exclude from diffs."""
    ids = set()
    for snapshot in self._snapshots:
      ids.add(id(snapshot))
      for v in snapshot.values():
        ids.add(id(v))
    return ids