xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/memory_checker.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Memory leak detection utility."""
16
17from tensorflow.python.framework.python_memory_checker import _PythonMemoryChecker
18from tensorflow.python.profiler import trace
19from tensorflow.python.util import tf_inspect
20
21try:
22  from tensorflow.python.platform.cpp_memory_checker import _CppMemoryChecker as CppMemoryChecker  # pylint:disable=g-import-not-at-top
23except ImportError:
24  CppMemoryChecker = None
25
26
def _get_test_name_best_effort():
  """Returns the name of the currently running test, if one can be found.

  Scans the frames returned by `tf_inspect.stack()` in order and uses the
  first frame whose function name starts with `test` and whose local `self`
  yields a class name. Frames that match by name but cannot provide a class
  name (for example, a plain function with no `self`) are skipped and the
  scan continues.

  Returns:
    A string of the form `ClassName.function_name`, or `None` if no suitable
    frame is found.
  """
  for frame_info in tf_inspect.stack():
    # Index 3 is the function name and index 0 is the frame object.
    function_name = frame_info[3]
    if not function_name.startswith('test'):
      continue
    try:
      class_name = frame_info[0].f_locals['self'].__class__.__name__
      return class_name + '.' + function_name
    except Exception:  # pylint:disable=broad-except
      # Best effort: any ordinary failure (no `self` local, unusual attribute
      # access, non-string name) just means this frame is not usable.
      # `KeyboardInterrupt` and `SystemExit` are deliberately not caught so
      # that interrupts and interpreter exits are never swallowed here.
      continue

  return None
39
40
41# TODO(kkb): Also create decorator versions for convenience.
class MemoryChecker(object):
  """Context manager for detecting Python and C++ memory leaks.

  Intended for both testing and debugging. A Python-level checker is always
  used; a C++-level checker is used in addition when `CppMemoryChecker` could
  be imported. Basic usage:

  >>> # MemoryChecker() context manager tracks memory status inside its scope.
  >>> with MemoryChecker() as memory_checker:
  >>>   tensors = []
  >>>   for _ in range(10):
  >>>     # Simulating `tf.constant(1)` object leak every iteration.
  >>>     tensors.append(tf.constant(1))
  >>>
  >>>     # Take a memory snapshot for later analysis.
  >>>     memory_checker.record_snapshot()
  >>>
  >>> # `report()` generates a html graph file showing allocations over
  >>> # snapshots per every stack trace.
  >>> memory_checker.report()
  >>>
  >>> # This assertion will detect `tf.constant(1)` object leak.
  >>> memory_checker.assert_no_leak_if_all_possibly_except_one()

  Call `record_snapshot()` exactly once per iteration, always from the same
  location. The detection algorithm assumes that a leak, if present, shows up
  in a similar way on every snapshot.
  """

  @trace.trace_wrapper
  def __enter__(self):
    """Creates the underlying checkers and returns this object."""
    self._python_memory_checker = _PythonMemoryChecker()
    if CppMemoryChecker:
      # The C++ checker is constructed with the current test name, if any.
      test_name = _get_test_name_best_effort()
      self._cpp_memory_checker = CppMemoryChecker(test_name)
    return self

  @trace.trace_wrapper
  def __exit__(self, exc_type, exc_value, traceback):
    """Stops the C++ checker, when one is in use."""
    if not CppMemoryChecker:
      return
    self._cpp_memory_checker.stop()

  # Deliberately not wrapped with trace_wrapper: tracing this call would
  # contaminate the snapshot being taken.
  def record_snapshot(self):
    """Takes a memory snapshot for later analysis.

    Call this exactly once per iteration, always from the same location. The
    detection algorithm assumes that a leak, if present, shows up in a similar
    way on every snapshot.

    How many `record_snapshot()` calls are advisable depends on the complexity
    of the code under test and on its allocation pattern.
    """
    self._python_memory_checker.record_snapshot()
    if not CppMemoryChecker:
      return
    self._cpp_memory_checker.record_snapshot()

  @trace.trace_wrapper
  def report(self):
    """Generates an html graph file showing allocations over snapshots.

    It creates a temporary directory and puts all the output files there.
    When running under Google internal testing infra, the directory provided
    by the infra is used instead.
    """
    self._python_memory_checker.report()
    if not CppMemoryChecker:
      return
    self._cpp_memory_checker.report()

  @trace.trace_wrapper
  def assert_no_leak_if_all_possibly_except_one(self):
    """Raises an exception if a leak is detected.

    A series of allocations is classified as a leak if it is of the same type
    (Python) or it happens at the same stack trace (C++) at every snapshot,
    possibly except for one snapshot.
    """
    self._python_memory_checker.assert_no_leak_if_all_possibly_except_one()
    if not CppMemoryChecker:
      return
    self._cpp_memory_checker.assert_no_leak_if_all_possibly_except_one()

  @trace.trace_wrapper
  def assert_no_new_python_objects(self, threshold=None):
    """Raises an exception if new Python objects have been created.

    The number of new Python objects per type is computed from the first and
    the last snapshots.

    Args:
      threshold: A dictionary of [Type name string], [count] pairs. No
        exception is raised if the new Python objects are under this
        threshold.
    """
    python_checker = self._python_memory_checker
    python_checker.assert_no_new_objects(threshold=threshold)
135