1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functions used to extract and analyze stacks. Faster than Python libs.""" 16# pylint: disable=g-bad-name 17import collections 18import inspect 19import threading 20 21# TODO(b/138203821): change to from ...util import ... once the bug is fixed. 22from tensorflow.python.util import _tf_stack 23 24# Generally such lookups should be done using `threading.local()`. See 25# https://blogs.gnome.org/jamesh/2008/06/11/tls-python/ for a detailed 26# explanation of why. However the transform stacks are expected to be empty 27# when a thread is joined, so reusing the key does not introduce a correctness 28# issue. Moreover, get_ident is faster than storing and retrieving a unique 29# key in a thread local store. 30_get_thread_key = threading.get_ident 31 32 33# TODO(mdan): Move these to C++ as well. 34# Moving to C++ can further avoid extra copies made by get_effective_map. 
# Per-thread stacks of active transforms, keyed by `_get_thread_key()`.
# Every new stack starts with a sentinel transform whose effective map/set is
# empty, so `stack[-1]` is always a valid transform even if nothing has been
# entered on the current thread. The factories may reference classes defined
# further down because a lambda body is only evaluated on a missing key.
_source_mapper_stacks = collections.defaultdict(lambda: [SentinelMapper()])
_source_filter_stacks = collections.defaultdict(lambda: [SentinelFilter()])


class StackTraceTransform(object):
  """Base class for stack trace transformation functions.

  Instances are context managers. Entering one pushes it onto the current
  thread's stack in `_stack_dict`, records the previous top of that stack as
  `self.parent`, and calls `update()` so the subclass can refresh its native
  state. Exiting pops it again. Exceptions raised inside the `with` body are
  not suppressed.

  Subclasses must set `_stack_dict` and implement `update()`.
  """

  _stack_dict = None  # Subclasses should override
  _thread_key = None

  def __enter__(self):
    # Any given instance is assumed to be used by a single thread, which reduces
    # expensive thread local lookups.
    if self._thread_key is None:
      self._thread_key = _get_thread_key()
    else:
      assert self._thread_key == _get_thread_key(), 'Shared across threads?'

    stack = self._stack_dict[self._thread_key]
    self.parent = stack[-1]
    stack.append(self)
    self.update()
    return self

  def __exit__(self, unused_type, unused_value, unused_traceback):
    top = self._stack_dict[self._thread_key].pop()
    # Enter/exit must be properly nested on this thread.
    assert top is self, 'Concurrent access?'

  def update(self):
    """Refreshes native state after this transform becomes active."""
    raise NotImplementedError('subclasses need to override this')


class StackTraceMapper(StackTraceTransform):
  """Allows remapping traceback information to different source code."""
  _stack_dict = _source_mapper_stacks

  def __init__(self):
    self.internal_map = _tf_stack.PyBindSourceMap()

  def update(self):
    # Push the effective source map into the native map as a tuple of
    # (key, value) pairs.
    self.internal_map.update_to(tuple(self.get_effective_source_map().items()))

  def get_effective_source_map(self):
    """Returns a map (filename, lineno) -> (filename, lineno, function_name)."""
    raise NotImplementedError('subclasses need to override this')


# Shared empty map returned by `SentinelMapper`; treat as read-only.
EMPTY_DICT = {}


class SentinelMapper(StackTraceMapper):
  """Bottom-of-stack mapper that remaps nothing."""

  def get_effective_source_map(self):
    return EMPTY_DICT


class StackTraceFilter(StackTraceTransform):
  """Allows filtering traceback information by removing superfluous frames."""
  _stack_dict = _source_filter_stacks

  def __init__(self):
    self.internal_set = _tf_stack.PyBindFileSet()

  def update(self):
    # Any iterable of filenames is accepted; it is copied into a set before
    # being pushed to the native file set.
    self.internal_set.update_to(set(self.get_filtered_filenames()))

  def get_filtered_filenames(self):
    """Returns an iterable of filenames whose frames should be dropped."""
    raise NotImplementedError('subclasses need to override this')


# Shared empty set returned by `SentinelFilter`.
EMPTY_SET = frozenset()


class SentinelFilter(StackTraceFilter):
  """Bottom-of-stack filter that filters nothing."""

  def get_filtered_filenames(self):
    return EMPTY_SET


class CurrentModuleFilter(StackTraceFilter):
  """Filters stack frames from the module where this is used (best effort).

  The filtered file is the source file of the frame that called `__init__`.
  If a subclass's `__init__` calls this one via `super()`, that is the
  subclass's module rather than the end user's. When no source file can be
  identified, this filter contributes no filename of its own and only forwards
  its parent's filenames.
  """

  def __init__(self):
    super().__init__()
    filter_filename = None
    outer_f = None
    f = inspect.currentframe()
    try:
      # inspect.currentframe() may return None on interpreters without stack
      # frame support.
      if f is not None:
        # The current frame is __init__. The first outer frame should be the
        # caller.
        outer_f = f.f_back
        if outer_f is not None:
          # May be None if no source file can be identified (e.g. exec'd code).
          filter_filename = inspect.getsourcefile(outer_f)
      self._filename = filter_filename
      # This may be called repeatedly: once on entry by the superclass, then by
      # each child context manager.
      self._cached_set = None
    finally:
      # Avoid reference cycles, see:
      # https://docs.python.org/3.7/library/inspect.html#the-interpreter-stack
      del f
      del outer_f

  def get_filtered_filenames(self):
    """Returns this module's filename plus all filenames filtered by parents.

    Reads `self.parent`, so it must only be called after this filter has been
    entered. The result is cached after the first call.

    Returns:
      A frozenset of filenames.
    """
    if self._cached_set is not None:
      return self._cached_set

    if self._filename is None:
      # No caller source file was identified. A None entry is not a filename,
      # and it would presumably be rejected by the native file set, so
      # contribute nothing of our own.
      filtered_filenames = frozenset()
    else:
      filtered_filenames = frozenset((self._filename,))
    if self.parent is not None:
      # union() accepts any iterable, so parents whose
      # get_filtered_filenames() returns a list or plain set also work.
      filtered_filenames = filtered_filenames.union(
          self.parent.get_filtered_filenames())
    self._cached_set = filtered_filenames
    return filtered_filenames


def extract_stack():
  """An eager-friendly alternative to traceback.extract_stack.

  The innermost active mapper and filter on the calling thread are applied.

  Returns:
    A list-like stack summary (see `StackSummary`) of StackFrame-like objects,
    which are namedtuple-like objects with the following fields: filename,
    lineno, name, line, meant to masquerade as traceback.FrameSummary objects.
  """
  # N.B. ExtractStack in tf_stack.cc will drop this frame prior to
  # traversing the stack.
  # TODO(cheshire): Remove this function, use extract_stack_for_op or Python
  # traceback module.
  thread_key = _get_thread_key()
  return _tf_stack.extract_stack(
      _source_mapper_stacks[thread_key][-1].internal_map,
      _source_filter_stacks[thread_key][-1].internal_set)


# TODO(mdan): Revisit these - a single location is almost always sufficient.
def extract_stack_for_op(c_op, stacklevel=1):
  """Attaches the current stack trace to `c_op`.

  The innermost active mapper and filter on the calling thread are applied.

  Args:
    c_op: a TF_Operation object.
    stacklevel: An integer for ignoring Python wrapper stack frames.
      The default value of 1 skips this function's own frame.
  """
  # N.B. ExtractStack in tf_stack.cc will drop this frame prior to
  # traversing the stack.
  thread_key = _get_thread_key()
  _tf_stack.extract_stack_for_op(
      _source_mapper_stacks[thread_key][-1].internal_map,
      _source_filter_stacks[thread_key][-1].internal_set, c_op, stacklevel)


# Public aliases for the native stack/frame types, named after their
# `traceback` module counterparts.
StackSummary = _tf_stack.StackTraceWrapper
FrameSummary = _tf_stack.StackFrame