xref: /aosp_15_r20/external/tensorflow/tensorflow/python/util/tf_stack.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functions used to extract and analyze stacks.  Faster than Python libs."""
16# pylint: disable=g-bad-name
17import collections
18import inspect
19import threading
20
21# TODO(b/138203821): change to from ...util import ... once the bug is fixed.
22from tensorflow.python.util import _tf_stack
23
# Generally such lookups should be done using `threading.local()`. See
# https://blogs.gnome.org/jamesh/2008/06/11/tls-python/ for a detailed
# explanation of why. However, the transform stacks are expected to be back in
# their initial state (holding only the sentinel) when a thread is joined, so a
# thread identifier that is later reused by another thread does not introduce
# a correctness issue. Moreover, get_ident is faster than storing and
# retrieving a unique key in a thread-local store.
# Key used to select the calling thread's transform stacks below.
_get_thread_key = threading.get_ident


# TODO(mdan): Move these to C++ as well.
# Moving to C++ can further avoid extra copies made by
# get_effective_source_map.
#
# Each map goes from a thread key to a list used as a stack of active
# transforms. A new stack is seeded with a sentinel so that `stack[-1]` is
# defined as long as pushes and pops stay balanced. The sentinel classes are
# defined further down in this module; the factories only look them up when a
# stack is first created.
_source_mapper_stacks = collections.defaultdict(
    lambda: [SentinelMapper()]
)
_source_filter_stacks = collections.defaultdict(
    lambda: [SentinelFilter()]
)
37
38
class StackTraceTransform(object):
  """Base class for stack trace transformation functions.

  Instances are context managers. `__enter__` pushes the instance onto the
  calling thread's stack in `_stack_dict`, records the previous top of that
  stack as `parent`, and calls `update()`. `__exit__` pops it again.
  Subclasses set `_stack_dict` and implement `update()`.
  """

  _stack_dict = None  # Subclasses should override
  _thread_key = None
  # The transform that was on top of the stack when this one was last entered.
  # None until the first `__enter__`.
  parent = None

  def __enter__(self):
    # Any given instance is assumed to be used by a single thread, which reduces
    # expensive thread local lookups.
    if self._thread_key is None:
      self._thread_key = _get_thread_key()
    else:
      assert self._thread_key == _get_thread_key(), 'Shared across threads?'

    stack = self._stack_dict[self._thread_key]
    self.parent = stack[-1]
    stack.append(self)
    try:
      self.update()
    except BaseException:
      # `__exit__` is not called when `__enter__` raises, so undo the push
      # here. Otherwise this half-initialized transform would stay on top of
      # the thread's stack, and the enclosing transform's `__exit__` would pop
      # it and trip the 'Concurrent access?' assertion.
      stack.pop()
      raise
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    top = self._stack_dict[self._thread_key].pop()
    assert top is self, 'Concurrent access?'

  def update(self):
    """Refreshes derived state when entered; subclasses must override."""
    raise NotImplementedError('subclasses need to override this')
65
66
class StackTraceMapper(StackTraceTransform):
  """Remaps traceback entries so they point at different source code.

  Subclasses supply the mapping through `get_effective_source_map`. Its items
  are copied into a native `PyBindSourceMap` whenever `update()` runs.
  """
  _stack_dict = _source_mapper_stacks

  def __init__(self):
    self.internal_map = _tf_stack.PyBindSourceMap()

  def update(self):
    effective_map = self.get_effective_source_map()
    # The native map is refreshed from a tuple of (key, value) pairs.
    pairs = tuple(effective_map.items())
    self.internal_map.update_to(pairs)

  def get_effective_source_map(self):
    """Returns a map (filename, lineno) -> (filename, lineno, function_name)."""
    raise NotImplementedError('subclasses need to override this')
80
81
# Shared empty source map returned by SentinelMapper. It is one shared object,
# so it must not be mutated.
EMPTY_DICT = dict()
83
84
class SentinelMapper(StackTraceMapper):
  """No-op mapper that seeds the bottom of each thread's mapper stack."""

  def get_effective_source_map(self):
    # Nothing to remap; every sentinel hands out the same shared empty map.
    empty_map = EMPTY_DICT
    return empty_map
89
90
class StackTraceFilter(StackTraceTransform):
  """Allows filtering traceback information by removing superfluous frames.

  Subclasses name the files to drop via `get_filtered_filenames`. The names
  are copied into a native `PyBindFileSet` whenever `update()` runs.
  """
  _stack_dict = _source_filter_stacks

  def __init__(self):
    self.internal_set = _tf_stack.PyBindFileSet()

  def update(self):
    names = self.get_filtered_filenames()
    # Any iterable is accepted here; it is converted to a plain set first.
    self.internal_set.update_to(set(names))

  def get_filtered_filenames(self):
    """Returns the filenames whose frames should be dropped."""
    raise NotImplementedError('subclasses need to override this')
103
104
# Shared, immutable empty filename set returned by SentinelFilter.
EMPTY_SET = frozenset(())
106
107
class SentinelFilter(StackTraceFilter):
  """No-op filter that seeds the bottom of each thread's filter stack."""

  def get_filtered_filenames(self):
    # Nothing to filter; every sentinel hands out the same empty frozenset.
    no_files = EMPTY_SET
    return no_files
112
113
class CurrentModuleFilter(StackTraceFilter):
  """Filters stack frames from the module where this is used (best effort).

  The module is identified by the source file of the frame that constructs
  the filter. The filtered set also includes whatever the parent transform
  filters (the transform that was on top of the stack when this one was
  entered).
  """

  def __init__(self):
    super().__init__()
    filter_filename = None
    outer_f = None
    f = inspect.currentframe()
    try:
      if f is not None:
        # The current frame is __init__. The first outer frame should be the
        # caller.
        outer_f = f.f_back
        if outer_f is not None:
          # May still be None if the caller's source file cannot be located.
          filter_filename = inspect.getsourcefile(outer_f)
      self._filename = filter_filename
      # Memoized result of get_filtered_filenames(). That method may be called
      # repeatedly: once on entry via update(), then by each child filter
      # entered while this one is on the stack.
      # NOTE(review): the memo is never invalidated, so it assumes this
      # instance is not re-entered under a different parent -- confirm.
      self._cached_set = None
    finally:
      # Avoid reference cycles, see:
      # https://docs.python.org/3.7/library/inspect.html#the-interpreter-stack
      del f
      del outer_f

  def get_filtered_filenames(self):
    """Returns a frozenset of filenames to filter, including the parent's."""
    if self._cached_set is not None:
      return self._cached_set

    if self._filename is None:
      # The calling module could not be determined. Contribute nothing rather
      # than handing a None entry to the native file set.
      filtered_filenames = frozenset()
    else:
      filtered_filenames = frozenset((self._filename,))
    parent = getattr(self, 'parent', None)
    if parent is not None:
      # union() accepts any iterable. StackTraceFilter.update() only requires
      # get_filtered_filenames() to return an iterable, and `|` would raise
      # TypeError for a non-set parent result.
      filtered_filenames = filtered_filenames.union(
          parent.get_filtered_filenames())
    self._cached_set = filtered_filenames
    return filtered_filenames
148
149
def extract_stack():
  """An eager-friendly alternative to traceback.extract_stack.

  The stack is captured using the innermost source mapper and stack filter
  that are active on the calling thread.

  Returns:
    A list-like object of StackFrame-like entries. Each entry is a
    namedtuple-like object with the fields filename, lineno, name and line,
    meant to masquerade as traceback.FrameSummary objects.
  """
  # N.B ExtractStack in tf_stack.cc will drop this frame prior to
  # traversing the stack, so the native call must stay directly in this
  # function (no intermediate helper frames).
  # TODO(cheshire): Remove this function, use extract_stack_for_op or Python
  # traceback module.
  key = _get_thread_key()
  source_map = _source_mapper_stacks[key][-1].internal_map
  file_set = _source_filter_stacks[key][-1].internal_set
  return _tf_stack.extract_stack(source_map, file_set)
166
167
# TODO(mdan): Revisit these - a single location is almost always sufficient.
def extract_stack_for_op(c_op, stacklevel=1):
  """Attaches the current stack trace to `c_op`.

  The trace is captured using the innermost source mapper and stack filter
  that are active on the calling thread.

  Args:
    c_op: a TF_Operation object.
    stacklevel: An integer for ignoring Python wrapper stack frames.
      The default value of 1 ignores this function from the frame.
  """
  # N.B ExtractStack in tf_stack.cc will drop this frame prior to
  # traversing the stack, so the native call must stay directly in this
  # function (no intermediate helper frames).
  key = _get_thread_key()
  source_map = _source_mapper_stacks[key][-1].internal_map
  file_set = _source_filter_stacks[key][-1].internal_set
  _tf_stack.extract_stack_for_op(source_map, file_set, c_op, stacklevel)
183
184
# Public aliases for the native stack types, named after their counterparts in
# the standard `traceback` module.
FrameSummary = _tf_stack.StackFrame
StackSummary = _tf_stack.StackTraceWrapper
187