xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/python_memory_checker.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python memory leak detection utility.

Please don't use this class directly.  Instead, use `MemoryChecker` wrapper.
"""
19
20import collections
21import copy
22import gc
23
24from tensorflow.python.framework import _python_memory_checker_helper
25from tensorflow.python.platform import tf_logging as logging
26from tensorflow.python.profiler import trace
27
28
29def _get_typename(obj):
30  """Return human readable pretty type name string."""
31  objtype = type(obj)
32  name = objtype.__name__
33  module = getattr(objtype, '__module__', None)
34  if module:
35    return '{}.{}'.format(module, name)
36  else:
37    return name
38
39
def _create_python_object_snapshot():
  """Snapshot all GC-tracked Python objects, grouped by type name.

  Runs a full garbage collection first so that unreachable objects do not
  pollute the snapshot.

  Returns:
    A `collections.defaultdict(set)` mapping a pretty type name (see
    `_get_typename`) to the set of `id()`s of live objects of that type.
  """
  gc.collect()
  snapshot = collections.defaultdict(set)
  for live_obj in gc.get_objects():
    snapshot[_get_typename(live_obj)].add(id(live_obj))
  return snapshot
47
48
49def _snapshot_diff(old_snapshot, new_snapshot, exclude_ids):
50  result = collections.Counter()
51  for new_name, new_ids in new_snapshot.items():
52    old_ids = old_snapshot[new_name]
53    result[new_name] = len(new_ids - exclude_ids) - len(old_ids - exclude_ids)
54
55  # This removes zero or negative value entries.
56  result += collections.Counter()
57  return result
58
59
class _PythonMemoryChecker(object):
  """Python memory leak detection class."""

  def __init__(self):
    self._snapshots = []

    # Cache a single closure up front so that repeated record_snapshot() calls
    # do not allocate fresh function objects, which would contaminate the very
    # leak measurement this class performs.
    def _record_snapshot():
      self._snapshots.append(_create_python_object_snapshot())

    self._record_snapshot = _record_snapshot

  # Deliberately NOT decorated with trace_wrapper: tracing allocations would
  # contaminate the snapshot.
  def record_snapshot(self):
    # Calling through `mark_stack_trace_and_call` places a
    # "_python_memory_checker_helper" marker in the C++ stack trace.  The C++
    # side uses that marker to filter out memory allocations caused by the
    # checker itself, since we only care about growth in the code under test.
    _python_memory_checker_helper.mark_stack_trace_and_call(
        self._record_snapshot)

  @trace.trace_wrapper
  def report(self):
    # TODO(kkb): Implement.
    pass

  @trace.trace_wrapper
  def assert_no_leak_if_all_possibly_except_one(self):
    """Raises an exception if a leak is detected.

    This algorithm classifies a series of allocations as a leak if it's the same
    type at every snapshot, but possibly except one snapshot.
    """
    # Diffs between each pair of consecutive snapshots.
    diffs = [
        self._snapshot_diff(index, index + 1)
        for index in range(len(self._snapshots) - 1)
    ]

    # For each type name, the number of consecutive-snapshot diffs where the
    # live-object count grew.
    growth_counts = collections.Counter()
    for diff in diffs:
      for name, count in diff.items():
        if count > 0:
          growth_counts[name] += 1

    # "Leaking" means growth in every diff, tolerating at most one exception.
    leaking_object_names = {
        name for name, growth in growth_counts.items()
        if growth >= len(diffs) - 1
    }

    if leaking_object_names:
      object_list_to_print = '\n'.join(
          ' - ' + name for name in leaking_object_names)
      raise AssertionError(
          'These Python objects were allocated in every snapshot possibly '
          f'except one.\n\n{object_list_to_print}')

  @trace.trace_wrapper
  def assert_no_new_objects(self, threshold=None):
    """Assert no new Python objects."""
    threshold = threshold or {}

    # Net growth between the first and the last snapshot.
    count_diff = self._snapshot_diff(0, -1)
    original_count_diff = copy.deepcopy(count_diff)
    count_diff.subtract(collections.Counter(threshold))

    if max(count_diff.values() or [0]) > 0:
      raise AssertionError('New Python objects created exceeded the threshold.'
                           '\nPython object threshold:\n'
                           f'{threshold}\n\nNew Python objects:\n'
                           f'{original_count_diff.most_common()}')
    elif min(count_diff.values(), default=0) < 0:
      # Growth below threshold: the threshold is likely stale; warn only.
      logging.warning('New Python objects created were less than the threshold.'
                      '\nPython object threshold:\n'
                      f'{threshold}\n\nNew Python objects:\n'
                      f'{original_count_diff.most_common()}')

  @trace.trace_wrapper
  def _snapshot_diff(self, old_index, new_index):
    # Exclude the checker's own bookkeeping containers from the diff.
    return _snapshot_diff(self._snapshots[old_index],
                          self._snapshots[new_index],
                          self._get_internal_object_ids())

  @trace.trace_wrapper
  def _get_internal_object_ids(self):
    """Return ids of the snapshot dicts and id-sets owned by this checker."""
    internal_ids = set()
    for snapshot in self._snapshots:
      internal_ids.add(id(snapshot))
      internal_ids.update(map(id, snapshot.values()))
    return internal_ids
154