1from __future__ import annotations 2 3import dataclasses 4import json 5import re 6import traceback 7from typing import Any, Callable, Union 8 9from torch._logging import LazyString 10from torch.onnx._internal.diagnostics.infra import sarif 11 12 13# A list of types in the SARIF module to support pretty printing. 14# This is solely for type annotation for the functions below. 15_SarifClass = Union[ 16 sarif.SarifLog, 17 sarif.Run, 18 sarif.ReportingDescriptor, 19 sarif.Result, 20] 21 22 23def lazy_format_exception(exception: Exception) -> LazyString: 24 return LazyString( 25 lambda: "\n".join( 26 ( 27 "```", 28 *traceback.format_exception( 29 type(exception), exception, exception.__traceback__ 30 ), 31 "```", 32 ) 33 ), 34 ) 35 36 37def snake_case_to_camel_case(s: str) -> str: 38 splits = s.split("_") 39 if len(splits) <= 1: 40 return s 41 return "".join([splits[0], *map(str.capitalize, splits[1:])]) 42 43 44def camel_case_to_snake_case(s: str) -> str: 45 return re.sub(r"([A-Z])", r"_\1", s).lower() 46 47 48def kebab_case_to_snake_case(s: str) -> str: 49 return s.replace("-", "_") 50 51 52def _convert_key( 53 object: dict[str, Any] | Any, convert: Callable[[str], str] 54) -> dict[str, Any] | Any: 55 """Convert and update keys in a dictionary with "convert". 56 57 Any value that is a dictionary will be recursively updated. 58 Any value that is a list will be recursively searched. 59 60 Args: 61 object: The object to update. 62 convert: The function to convert the keys, e.g. `kebab_case_to_snake_case`. 63 64 Returns: 65 The updated object. 66 """ 67 if not isinstance(object, dict): 68 return object 69 new_dict = {} 70 for k, v in object.items(): 71 new_k = convert(k) 72 if isinstance(v, dict): 73 new_v = _convert_key(v, convert) 74 elif isinstance(v, list): 75 new_v = [_convert_key(elem, convert) for elem in v] 76 else: 77 new_v = v 78 if new_v is None: 79 # Otherwise unnecessarily bloated sarif log with "null"s. 80 continue 81 if new_v == -1: 82 # WAR: -1 as default value shouldn't be logged into sarif. 83 continue 84 85 new_dict[new_k] = new_v 86 87 return new_dict 88 89 90def sarif_to_json(attr_cls_obj: _SarifClass, indent: str | None = " ") -> str: 91 dict = dataclasses.asdict(attr_cls_obj) 92 dict = _convert_key(dict, snake_case_to_camel_case) 93 return json.dumps(dict, indent=indent, separators=(",", ":")) 94 95 96def format_argument(obj: Any) -> str: 97 return f"{type(obj)}" 98 99 100def display_name(fn: Callable) -> str: 101 if hasattr(fn, "__qualname__"): 102 return fn.__qualname__ 103 elif hasattr(fn, "__name__"): 104 return fn.__name__ 105 else: 106 return str(fn) 107