1# mypy: allow-untyped-defs 2from types import TracebackType 3from typing import List, Optional 4import tempfile 5import traceback 6import contextlib 7import inspect 8import os.path 9 10# This file contains utilities for ensuring dynamically compile()'d 11# code fragments display their line numbers in backtraces. 12# 13# The constraints: 14# 15# - We don't have control over the user exception printer (in particular, 16# we cannot assume the linecache trick will work, c.f. 17# https://stackoverflow.com/q/50515651/23845 ) 18# 19# - We don't want to create temporary files every time we compile() 20# some code; file creation should happen lazily only at exception 21# time. Arguably, you *should* be willing to write out your 22# generated Python code to file system, but in some situations 23# (esp. library code) it would violate user expectation to write 24# to the file system, so we try to avoid it. In particular, we'd 25# like to keep the files around, so users can open up the files 26# mentioned in the trace; if the file is invisible, we want to 27# avoid clogging up the filesystem. 28# 29# If this is not a constraint for you, there is a substantially simpler 30# way to implement the functionality in this PR: instead of using 31# eval/exec directly, just always write a Python file to filesystem 32# and compile that. 33# 34# - You have control over a context where the compiled code will get 35# executed, so that we can interpose while the stack is unwinding 36# (otherwise, we have no way to interpose on the exception printing 37# process.) 38# 39# There are two things you have to do to make use of the utilities here: 40# 41# - When you compile your source code, you must save its string source 42# in its f_globals under the magic name "__compile_source__" 43# 44# - Before running the compiled code, enter the 45# report_compile_source_on_error() context manager. 
@contextlib.contextmanager
def report_compile_source_on_error():
    """
    Context manager that makes exec/eval'd code show real source lines in tracebacks.

    If an ``Exception`` escapes the ``with`` body, its traceback is walked and
    every frame whose code object has ``co_filename == "<string>"`` and whose
    globals contain ``"__compile_source__"`` is replaced by a look-alike frame
    that points at a temporary ``.py`` file holding that source.  The same
    exception object is then re-raised with the rewritten traceback.

    The temporary files are deliberately not deleted so that the user can open
    the paths mentioned in the traceback.  One file is written per distinct
    ``__compile_source__`` string seen in the traceback.
    """
    try:
        yield
    except Exception as exc:
        tb = exc.__traceback__

        # Walk the traceback, looking for frames that have
        # source attached
        stack = []
        # Maps each distinct __compile_source__ string to the temporary file
        # it has been written to, so that several frames sharing one source
        # (e.g. recursion inside generated code) reuse a single file.
        source_files = {}
        while tb is not None:
            filename = tb.tb_frame.f_code.co_filename
            source = tb.tb_frame.f_globals.get("__compile_source__")

            if filename == "<string>" and source is not None:
                # What black magic are we doing here?  Intuitively, what
                # we would like to do is overwrite the co_filename on any
                # frames that were generated from exec/eval so that they
                # point to a temporary file that has the actual line
                # information, so Python's default error printer can print
                # useful line information on it.
                #
                # Writing out the temporary file is easy.  But overwriting
                # co_filename is not!  You can't modify the code object
                # associated with a frame.  You can, however, reconstruct
                # a traceback with entirely new frames from scratch, so that's
                # what we do.  But there's another problem, which is how to
                # make the frame?
                #
                # The black magic is we make a frankenstein frame and code
                # object which resembles the original frame/code enough so
                # that it will print properly under traceback and the default
                # error printer, but IT IS NOT THE ORIGINAL FRAME (you
                # couldn't, e.g., execute its code with different variables
                # and expect it to work.)
                source_path = source_files.get(source)
                if source_path is None:
                    # Don't delete the temporary file so the user can inspect it
                    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".py") as f:
                        f.write(source)
                    source_path = f.name
                    source_files[source] = source_path
                stack.append(_make_fake_traceback(tb, source_path))
            else:
                stack.append(tb)

            tb = tb.tb_next

        # Reconstruct the linked list
        tb_next = None
        for tb in reversed(stack):
            tb.tb_next = tb_next
            tb_next = tb

        raise exc.with_traceback(tb_next)  # noqa: B904


def _make_fake_traceback(tb, source_path):
    """Build a stand-in traceback entry for ``tb`` whose code object reports ``source_path`` as its filename."""
    # Python doesn't let you construct FrameType directly, so just make
    # one by evaluating a tiny expression that returns its own frame.
    frame = tb.tb_frame
    code = compile('__inspect_currentframe()', source_path, 'eval')
    code = code.replace(co_name=frame.f_code.co_name)
    # Only code objects that expose co_linetable take this path (the
    # original author's note here reads "Python 3.11 only").
    if hasattr(frame.f_code, 'co_linetable'):
        # We can't copy ALL of the metadata over, because you
        # can cause Python to segfault this way.  What exactly
        # do we need?  We need enough information for
        # traceback to be able to print the exception
        # correctly.  Code reading Lib/traceback.py reveals
        # that traceback calls code.co_positions() in order to
        # get the augmented line/col numbers.  Objects/codeobject.c,
        # specifically _PyCode_InitAddressRange, reveals that
        # this iterator is initialized from co_linetable and
        # co_firstlineno.  So copy these we must!
        code = code.replace(  # type: ignore[call-arg]
            co_linetable=frame.f_code.co_linetable,  # type: ignore[attr-defined]
            co_firstlineno=frame.f_code.co_firstlineno,  # type: ignore[attr-defined]
        )
    # Evaluating the expression calls inspect.currentframe() from inside the
    # freshly created frame, which hands that frame back to us.
    fake_frame = eval(
        code,
        frame.f_globals,
        {
            **frame.f_locals,
            '__inspect_currentframe': inspect.currentframe
        }
    )
    return TracebackType(
        None, fake_frame, tb.tb_lasti, tb.tb_lineno
    )


def shorten_filename(fn, *, base=None):
    """
    Shorten a source filepath, with the assumption that torch/ subdirectories don't need to be shown to user.

    The longest common path prefix of ``fn`` and ``base`` (as computed by
    ``os.path.commonpath``) is removed from the front of ``fn``, together with
    the path separator that follows it.

    Args:
        fn: the file path to shorten.
        base: directory to shorten against; defaults to the parent of the
            directory containing this module.

    Returns:
        The shortened path, or ``fn`` unchanged when ``os.path.commonpath``
        raises ``ValueError`` for the pair.
    """
    if base is None:
        base = os.path.dirname(os.path.dirname(__file__))
    # Truncate torch/foo.py to foo.py
    try:
        prefix = os.path.commonpath([fn, base])
    except ValueError:
        return fn
    # The common prefix may already end in a separator (filesystem root) or
    # be empty, so strip separators explicitly instead of blindly skipping
    # one extra character, which would eat the first letter of the result.
    separators = os.sep if os.altsep is None else os.sep + os.altsep
    return fn[len(prefix):].lstrip(separators)


def format_frame(frame, *, base=None, line=False):
    """
    Format a FrameSummary in a short way, without printing full absolute path or code.

    The idea is the result fits on a single line.

    Args:
        frame: the ``traceback.FrameSummary`` to format.
        base: forwarded to ``shorten_filename``.
        line: when true, prefix the result with ``frame.line`` followed by ``" # "``.

    Returns:
        A string of the form ``"<short filename>:<lineno> in <name>"``.
    """
    extra_line = ""
    if line:
        extra_line = f"{frame.line}  # "
    return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}"


def format_traceback_short(tb):
    """Format a TracebackType in a short way, printing only the inner-most frame."""
    # extract_tb lists frames oldest-first, so the last entry is the frame
    # where the exception was raised.
    return format_frame(traceback.extract_tb(tb)[-1])


class CapturedTraceback:
    """
    Holder for a traceback object produced by ``torch._C._profiler.gather_traceback``.

    ``tb`` is the captured traceback (or ``None`` once released) and ``skip``
    is the number of leading entries to drop when the traceback is summarized.
    """

    __slots__ = ['tb', 'skip']

    def __init__(self, tb, skip=0):
        self.tb = tb
        self.skip = skip

    def cleanup(self):
        """Drop the reference to the captured traceback; later summaries are empty."""
        self.tb = None

    def summary(self):
        """Symbolize the captured traceback and return it as a ``traceback.StackSummary``."""
        import torch._C._profiler

        if self.tb is None:
            # TODO: Maybe indicate that the traceback was elided?
            return traceback.StackSummary()

        return _extract_symbolized_tb(
            torch._C._profiler.symbolize_tracebacks([self.tb])[0],
            self.skip
        )

    def __getstate__(self):
        # State is a (dict_state, slots_state) pair, the form pickle uses for
        # classes with __slots__; there is no instance __dict__, hence None.
        return (None, {
            'tb': None,  # TB is not pickleable
            'skip': self.skip,
        })

    @staticmethod
    def extract(*, script=False, cpp=False, skip=0):
        """
        Like traceback.extract_stack(), but faster (approximately 20x faster); it
        is fast enough that you can unconditionally log stacks this way as part of
        normal execution.  It returns a CapturedTraceback wrapping the result of
        torch._C._profiler.gather_traceback, which must be formatted specially
        with CapturedTraceback.format or CapturedTraceback.format_all.

        By default, this only reports Python backtraces (like extract_stack).  You
        can set the script/cpp kwargs to also turn on TorchScript/C++ trace
        reporting.

        ``skip`` is the number of additional innermost frames to omit; it must
        be 0 when ``script`` or ``cpp`` is set.
        """
        import torch._C._profiler

        if script or cpp:
            assert skip == 0, "skip with script/cpp NYI"

        return CapturedTraceback(
            torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
            # Elide extract() frame if we don't have script/cpp frames.  If
            # we do have those frames, it doesn't work so force zero.
            0 if script or cpp else skip + 1
        )

    def format(self):
        """
        Formats this CapturedTraceback into a list of strings equivalent to the
        output of traceback.format_list.  Note that if you have several
        CapturedTracebacks with C++ traces, it is better not to use this
        function and use the batch formatting API CapturedTraceback.format_all
        to amortize the cost of symbolization.
        """
        return traceback.format_list(self.summary())

    @staticmethod
    def format_all(tbs):
        """
        Bulk version of CapturedTraceback.format.  Returns a list of list of strings.
        """
        import torch._C._profiler

        # Tracebacks that were already cleaned up format to an empty list
        # right away; the rest are symbolized together in one batch below.
        rs: List[Optional[List[str]]] = []
        delayed_idxs = []
        for i, tb in enumerate(tbs):
            if tb.tb is None:
                rs.append([])
            else:
                rs.append(None)
                delayed_idxs.append(i)

        stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs])
        for i, stb in zip(delayed_idxs, stbs):
            # Use the batch result directly; going through summary() would
            # symbolize every traceback a second time, one at a time.
            rs[i] = traceback.format_list(_extract_symbolized_tb(stb, tbs[i].skip))

        return rs


def _extract_symbolized_tb(tb, skip):
    """
    Given a symbolized traceback from symbolize_tracebacks, return a StackSummary object of
    pre-processed stack trace entries.

    The first ``skip`` entries of ``tb`` are dropped and the remainder is
    emitted in reverse order.  Each entry is indexed by the keys
    ``'filename'``, ``'line'`` and ``'name'``.
    """
    stack = traceback.StackSummary()
    for f in reversed(tb[skip:]):
        stack.append(traceback.FrameSummary(f['filename'], f['line'], f['name']))
    return stack