xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/diagnostics/infra/formatter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import dataclasses
4import json
5import re
6import traceback
7from typing import Any, Callable, Union
8
9from torch._logging import LazyString
10from torch.onnx._internal.diagnostics.infra import sarif
11
12
# The SARIF dataclasses that can be pretty printed by the helpers below.
# This alias exists purely for type annotations; it is not used for any
# runtime checking.
_SarifClass = Union[
    sarif.SarifLog, sarif.Run, sarif.ReportingDescriptor, sarif.Result
]
21
22
23def lazy_format_exception(exception: Exception) -> LazyString:
24    return LazyString(
25        lambda: "\n".join(
26            (
27                "```",
28                *traceback.format_exception(
29                    type(exception), exception, exception.__traceback__
30                ),
31                "```",
32            )
33        ),
34    )
35
36
def snake_case_to_camel_case(s: str) -> str:
    """Convert ``snake_case`` to ``camelCase`` (e.g. ``"rule_id"`` -> ``"ruleId"``).

    The first ``_``-separated word is kept verbatim. Every following word goes
    through ``str.capitalize``, which also lowercases its remaining letters,
    and the underscores are dropped. A string without ``_`` is returned as-is.
    """
    first, *rest = s.split("_")
    if not rest:
        return s
    return first + "".join(word.capitalize() for word in rest)
42
43
def camel_case_to_snake_case(s: str) -> str:
    """Convert ``camelCase`` to ``snake_case`` (e.g. ``"ruleId"`` -> ``"rule_id"``).

    An ``_`` is inserted before every ASCII capital ``A``-``Z`` and the whole
    result is lowercased. A leading capital therefore yields a leading
    underscore (``"Foo"`` -> ``"_foo"``), and runs of capitals are split per
    letter.
    """
    pieces = []
    for ch in s:
        if "A" <= ch <= "Z":
            pieces.append("_")
        pieces.append(ch)
    return "".join(pieces).lower()
46
47
def kebab_case_to_snake_case(s: str) -> str:
    """Convert ``kebab-case`` to ``snake_case`` by turning every ``-`` into ``_``."""
    return "_".join(s.split("-"))
50
51
def _convert_key(
    object: dict[str, Any] | Any, convert: Callable[[str], str]
) -> dict[str, Any] | Any:
    """Return a copy of ``object`` with every dictionary key passed through ``convert``.

    Dictionaries are rebuilt recursively with converted keys. Lists are walked
    recursively, including lists nested directly inside lists, so dictionaries
    at any depth get their keys converted. Any other value (tuples included)
    is returned unchanged.

    While rebuilding a dictionary, entries whose converted value is ``None`` or
    compares equal to ``-1`` are dropped. List elements are never dropped.

    Args:
        object: The object to convert.
        convert: The function to convert the keys, e.g. `kebab_case_to_snake_case`.

    Returns:
        The converted object. The input itself is not modified.
    """
    if isinstance(object, list):
        return [_convert_key(elem, convert) for elem in object]
    if not isinstance(object, dict):
        return object
    new_dict = {}
    for k, v in object.items():
        new_k = convert(k)
        new_v = _convert_key(v, convert)
        if new_v is None:
            # Skip so the SARIF log is not bloated with "null"s.
            continue
        if new_v == -1:
            # Workaround: -1 is used as a default value and shouldn't be
            # logged into SARIF (note: this also drops -1.0).
            continue
        new_dict[new_k] = new_v
    return new_dict
88
89
def sarif_to_json(attr_cls_obj: _SarifClass, indent: str | None = " ") -> str:
    """Serialize a SARIF dataclass instance to a JSON string.

    The object is converted with ``dataclasses.asdict``. Its keys are then
    rewritten with ``snake_case_to_camel_case`` by ``_convert_key``, which also
    prunes ``None`` and ``-1`` values.

    Args:
        attr_cls_obj: The SARIF object to serialize.
        indent: Forwarded to ``json.dumps``; ``None`` gives a single-line output.

    Returns:
        The JSON text.
    """
    # Named ``as_dict`` rather than ``dict`` to avoid shadowing the builtin.
    as_dict = dataclasses.asdict(attr_cls_obj)
    camel_cased = _convert_key(as_dict, snake_case_to_camel_case)
    # Explicit separators drop the space json.dumps would otherwise put after ":".
    return json.dumps(camel_cased, indent=indent, separators=(",", ":"))
94
95
def format_argument(obj: Any) -> str:
    """Describe ``obj`` by its type only, e.g. ``"<class 'int'>"``."""
    obj_type = type(obj)
    return format(obj_type)
98
99
def display_name(fn: Callable) -> str:
    """Return a human-readable name for ``fn``.

    Prefers ``__qualname__``, then ``__name__``, and falls back to ``str(fn)``
    when neither attribute exists.
    """
    for attr in ("__qualname__", "__name__"):
        if hasattr(fn, attr):
            return getattr(fn, attr)
    return str(fn)
107