from __future__ import annotations

import functools
import inspect
import traceback
from typing import Any, Callable, Mapping, Sequence

from torch.onnx._internal.diagnostics.infra import _infra, formatter


def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame:
    """Returns a StackFrame for the given traceback.FrameSummary.

    The frame's source line text is used both as the location snippet and as
    the location message.
    """
    snippet = frame.line

    return _infra.StackFrame(
        location=_infra.Location(
            uri=frame.filename,
            line=frame.lineno,
            snippet=snippet,
            function=frame.name,
            message=snippet,
        )
    )


def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infra.Stack:
    """Returns the current Python call stack.

    Args:
        frames_to_skip: Number of innermost frames to leave out, not counting
            this function's own frame (which is always left out).
        frames_to_log: Maximum number of frames to include in the result.

    Returns:
        A Stack whose frames are ordered from newest to oldest.

    Raises:
        ValueError: If ``frames_to_skip`` or ``frames_to_log`` is negative.
    """
    if frames_to_skip < 0:
        raise ValueError("frames_to_skip must be non-negative")
    if frames_to_log < 0:
        raise ValueError("frames_to_log must be non-negative")
    frames_to_skip += 1  # Skip this function.
    stack = _infra.Stack()
    # Frames are returned in order of oldest to newest.
    frames = traceback.extract_stack(limit=frames_to_skip + frames_to_log)
    # Flip to newest-first so the skipped frames can be sliced off the front.
    frames.reverse()
    stack.frames = [python_frame(frame) for frame in frames[frames_to_skip:]]
    stack.message = "Python call stack"
    return stack


@functools.lru_cache
def _function_source_info(fn: Callable) -> tuple[Sequence[str], int, str | None]:
    """Returns the source lines, line number, and source file path for the given function.

    Essentially, inspect.getsourcelines() and inspect.getsourcefile() combined.
    Caching is applied to reduce the performance impact of this function.

    Raises:
        OSError: If the source code cannot be retrieved.
        TypeError: If ``fn`` is a built-in object, or is not hashable (the
            cache needs to hash its argument).
    """
    source_lines, lineno = inspect.getsourcelines(fn)
    return source_lines, lineno, inspect.getsourcefile(fn)


def function_location(fn: Callable) -> _infra.Location:
    """Returns a Location for the given function.

    The snippet is the first source line of ``fn`` with surrounding whitespace
    removed. When the source cannot be looked up (for example for built-in
    callables), the snippet is ``"<unknown>"`` and ``uri`` and ``line`` are
    ``None`` instead of an exception being raised.
    """
    source_lines: Sequence[str] = ()
    lineno: int | None = None
    uri: str | None = None
    try:
        source_lines, lineno, uri = _function_source_info(fn)
    except (OSError, TypeError):
        # inspect raises rather than returning empty results when no source is
        # available; keep the defaults above so the "<unknown>" fallback below
        # is actually used and diagnostics do not crash.
        pass
    snippet = source_lines[0].strip() if source_lines else "<unknown>"
    return _infra.Location(
        uri=uri,
        line=lineno,
        snippet=snippet,
        message=formatter.display_name(fn),
    )


def function_state(
    fn: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Mapping[str, Any]:
    """Returns the mapping of parameter names to values for a call to ``fn``.

    ``args`` and ``kwargs`` are bound against the signature of ``fn``. Only
    explicitly passed arguments are included; parameters left to their
    defaults are not part of the result.

    Raises:
        TypeError: If ``args`` and ``kwargs`` do not match the signature.
    """
    bind = inspect.signature(fn).bind(*args, **kwargs)
    return bind.arguments