xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/diagnostics/infra/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import functools
4import inspect
5import traceback
6from typing import Any, Callable, Mapping, Sequence
7
8from torch.onnx._internal.diagnostics.infra import _infra, formatter
9
10
def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame:
    """Converts a ``traceback.FrameSummary`` into an ``_infra.StackFrame``.

    The frame's source line text is used both as the location's snippet and
    as its message.

    Args:
        frame: The traceback frame summary to convert.

    Returns:
        A StackFrame whose location carries the frame's filename, line number,
        function name, and source line text.
    """
    source_text = frame.line
    location = _infra.Location(
        uri=frame.filename,
        line=frame.lineno,
        snippet=source_text,
        function=frame.name,
        message=source_text,
    )
    return _infra.StackFrame(location=location)
24
25
def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infra.Stack:
    """Captures the current Python call stack as an ``_infra.Stack``.

    Args:
        frames_to_skip: Number of innermost frames (beyond this function's own
            frame, which is always dropped) to leave out of the result.
        frames_to_log: Upper bound on the number of frames recorded.

    Returns:
        A Stack whose frames are ordered from newest to oldest and whose
        message is "Python call stack".

    Raises:
        ValueError: If ``frames_to_skip`` or ``frames_to_log`` is negative.
    """
    if frames_to_skip < 0:
        raise ValueError("frames_to_skip must be non-negative")
    if frames_to_log < 0:
        raise ValueError("frames_to_log must be non-negative")

    # One extra frame is skipped to hide this function itself.
    skip_count = frames_to_skip + 1
    call_stack = _infra.Stack()
    # extract_stack yields oldest -> newest; it must be called directly here
    # so that the frame being skipped really is this function's.
    captured = traceback.extract_stack(limit=skip_count + frames_to_log)
    newest_first = captured[::-1]
    call_stack.frames = list(map(python_frame, newest_first[skip_count:]))
    call_stack.message = "Python call stack"
    return call_stack
40
41
@functools.lru_cache
def _function_source_info(fn: Callable) -> tuple[Sequence[str], int, str | None]:
    """Looks up where and how the given function is defined in source.

    Bundles the results of ``inspect.getsourcelines()`` and
    ``inspect.getsourcefile()`` into one tuple. Results are memoized with
    ``functools.lru_cache`` so repeated lookups for the same function avoid
    re-running the inspection.

    Args:
        fn: The function to inspect.

    Returns:
        A ``(source_lines, starting_line_number, source_file_path)`` tuple.
    """
    lines_and_start = inspect.getsourcelines(fn)
    source_path = inspect.getsourcefile(fn)
    return (lines_and_start[0], lines_and_start[1], source_path)
51
52
def function_location(fn: Callable) -> _infra.Location:
    """Builds an ``_infra.Location`` pointing at the definition of ``fn``.

    Args:
        fn: The function to locate.

    Returns:
        A Location with the source file as ``uri``, the starting line number,
        the stripped first source line as ``snippet`` ("<unknown>" when no
        source lines are available), and the function's display name as
        ``message``.
    """
    lines, start_line, source_path = _function_source_info(fn)
    if lines:
        first_line = lines[0].strip()
    else:
        first_line = "<unknown>"
    return _infra.Location(
        uri=source_path,
        line=start_line,
        snippet=first_line,
        message=formatter.display_name(fn),
    )
63
64
def function_state(
    fn: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Mapping[str, Any]:
    """Maps the parameter names of ``fn`` to the values supplied by a call.

    The positional and keyword arguments are bound against the signature of
    ``fn``. Defaults are not applied, so parameters that were not explicitly
    passed are absent from the result.

    Args:
        fn: The callable whose signature is used for binding.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        A mapping from parameter name to bound argument value.

    Raises:
        TypeError: If the arguments do not match the signature of ``fn``.
    """
    signature = inspect.signature(fn)
    bound_call = signature.bind(*args, **kwargs)
    return bound_call.arguments
70