1# mypy: allow-untyped-defs 2"""A diagnostic context based on SARIF.""" 3 4from __future__ import annotations 5 6import contextlib 7import dataclasses 8import gzip 9import logging 10from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar 11from typing_extensions import Self 12 13from torch.onnx._internal.diagnostics import infra 14from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils 15from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version 16 17 18# This is a workaround for mypy not supporting Self from typing_extensions. 19_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic") 20diagnostic_logger: logging.Logger = logging.getLogger(__name__) 21 22 23@dataclasses.dataclass 24class Diagnostic: 25 rule: infra.Rule 26 level: infra.Level 27 message: str | None = None 28 locations: list[infra.Location] = dataclasses.field(default_factory=list) 29 stacks: list[infra.Stack] = dataclasses.field(default_factory=list) 30 graphs: list[infra.Graph] = dataclasses.field(default_factory=list) 31 thread_flow_locations: list[infra.ThreadFlowLocation] = dataclasses.field( 32 default_factory=list 33 ) 34 additional_messages: list[str] = dataclasses.field(default_factory=list) 35 tags: list[infra.Tag] = dataclasses.field(default_factory=list) 36 source_exception: Exception | None = None 37 """The exception that caused this diagnostic to be created.""" 38 logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger) 39 """The logger for this diagnostic. Defaults to 'diagnostic_logger' which has the same 40 log level setting with `DiagnosticOptions.verbosity_level`.""" 41 _current_log_section_depth: int = 0 42 43 def __post_init__(self) -> None: 44 pass 45 46 def sarif(self) -> sarif.Result: 47 """Returns the SARIF Result representation of this diagnostic.""" 48 message = self.message or self.rule.message_default_template 49 if self.additional_messages: 50 additional_message = "\n".join(self.additional_messages) 51 message_markdown = ( 52 f"{message}\n\n## Additional Message:\n\n{additional_message}" 53 ) 54 else: 55 message_markdown = message 56 57 kind: Literal["informational", "fail"] = ( 58 "informational" if self.level == infra.Level.NONE else "fail" 59 ) 60 61 sarif_result = sarif.Result( 62 message=sarif.Message(text=message, markdown=message_markdown), 63 level=self.level.name.lower(), # type: ignore[arg-type] 64 rule_id=self.rule.id, 65 kind=kind, 66 ) 67 sarif_result.locations = [location.sarif() for location in self.locations] 68 sarif_result.stacks = [stack.sarif() for stack in self.stacks] 69 sarif_result.graphs = [graph.sarif() for graph in self.graphs] 70 sarif_result.code_flows = [ 71 sarif.CodeFlow( 72 thread_flows=[ 73 sarif.ThreadFlow( 74 locations=[loc.sarif() for loc in self.thread_flow_locations] 75 ) 76 ] 77 ) 78 ] 79 sarif_result.properties = sarif.PropertyBag( 80 tags=[tag.value for tag in self.tags] 81 ) 82 return sarif_result 83 84 def with_location(self: Self, location: infra.Location) -> Self: 85 """Adds a location to the diagnostic.""" 86 self.locations.append(location) 87 return self 88 89 def with_thread_flow_location( 90 self: Self, location: infra.ThreadFlowLocation 91 ) -> Self: 92 """Adds a thread flow location to the diagnostic.""" 93 self.thread_flow_locations.append(location) 94 return self 95 96 def with_stack(self: Self, stack: infra.Stack) -> Self: 97 """Adds a stack to the diagnostic.""" 98 self.stacks.append(stack) 99 return self 100 101 def with_graph(self: Self, graph: infra.Graph) -> Self: 102 """Adds a graph to the diagnostic.""" 103 self.graphs.append(graph) 104 return self 105 106 @contextlib.contextmanager 107 def log_section( 108 self, level: int, message: str, *args, **kwargs 109 ) -> Generator[None, None, None]: 110 """ 111 Context manager for a section of log messages, denoted by a title message and increased indentation. 112 113 Same api as `logging.Logger.log`. 114 115 This context manager logs the given title at the specified log level, increases the current 116 section depth for subsequent log messages, and ensures that the section depth is decreased 117 again when exiting the context. 118 119 Args: 120 level: The log level. 121 message: The title message to log. 122 *args: The arguments to the message. Use `LazyString` to defer the 123 expensive evaluation of the arguments until the message is actually logged. 124 **kwargs: The keyword arguments for `logging.Logger.log`. 125 126 Yields: 127 None: This context manager does not yield any value. 128 129 Example: 130 >>> with DiagnosticContext("DummyContext", "1.0"): 131 ... rule = infra.Rule("RuleID", "DummyRule", "Rule message") 132 ... diagnostic = Diagnostic(rule, infra.Level.WARNING) 133 ... with diagnostic.log_section(logging.INFO, "My Section"): 134 ... diagnostic.log(logging.INFO, "My Message") 135 ... with diagnostic.log_section(logging.INFO, "My Subsection"): 136 ... diagnostic.log(logging.INFO, "My Submessage") 137 ... diagnostic.additional_messages 138 ['## My Section', 'My Message', '### My Subsection', 'My Submessage'] 139 """ 140 if self.logger.isEnabledFor(level): 141 indented_format_message = ( 142 f"##{'#' * self._current_log_section_depth } {message}" 143 ) 144 self.log( 145 level, 146 indented_format_message, 147 *args, 148 **kwargs, 149 ) 150 self._current_log_section_depth += 1 151 try: 152 yield 153 finally: 154 self._current_log_section_depth -= 1 155 156 def log(self, level: int, message: str, *args, **kwargs) -> None: 157 """Logs a message within the diagnostic. Same api as `logging.Logger.log`. 158 159 If logger is not enabled for the given level, the message will not be logged. 160 Otherwise, the message will be logged and also added to the diagnostic's additional_messages. 161 162 The default setting for `DiagnosticOptions.verbosity_level` is `logging.INFO`. Based on this default, 163 the log level recommendations are as follows. If you've set a different default verbosity level in your 164 application, please adjust accordingly: 165 166 - logging.ERROR: Log any events leading to application failure. 167 - logging.WARNING: Log events that might result in application issues or failures, although not guaranteed. 168 - logging.INFO: Log general useful information, ensuring minimal performance overhead. 169 - logging.DEBUG: Log detailed debug information, which might affect performance when logged. 170 171 Args: 172 level: The log level. 173 message: The message to log. 174 *args: The arguments to the message. Use `LazyString` to defer the 175 expensive evaluation of the arguments until the message is actually logged. 176 **kwargs: The keyword arguments for `logging.Logger.log`. 177 """ 178 if self.logger.isEnabledFor(level): 179 formatted_message = message % args 180 self.logger.log(level, formatted_message, **kwargs) 181 self.additional_messages.append(formatted_message) 182 183 def debug(self, message: str, *args, **kwargs) -> None: 184 """Logs a debug message within the diagnostic. Same api as logging.Logger.debug. 185 186 Checkout `log` for more details. 187 """ 188 self.log(logging.DEBUG, message, *args, **kwargs) 189 190 def info(self, message: str, *args, **kwargs) -> None: 191 """Logs an info message within the diagnostic. Same api as logging.Logger.info. 192 193 Checkout `log` for more details. 194 """ 195 self.log(logging.INFO, message, *args, **kwargs) 196 197 def warning(self, message: str, *args, **kwargs) -> None: 198 """Logs a warning message within the diagnostic. Same api as logging.Logger.warning. 199 200 Checkout `log` for more details. 201 """ 202 self.log(logging.WARNING, message, *args, **kwargs) 203 204 def error(self, message: str, *args, **kwargs) -> None: 205 """Logs an error message within the diagnostic. Same api as logging.Logger.error. 206 207 Checkout `log` for more details. 208 """ 209 self.log(logging.ERROR, message, *args, **kwargs) 210 211 def log_source_exception(self, level: int, exception: Exception) -> None: 212 """Logs a source exception within the diagnostic. 213 214 Invokes `log_section` and `log` to log the exception in markdown section format. 215 """ 216 self.source_exception = exception 217 with self.log_section(level, "Exception log"): 218 self.log(level, "%s", formatter.lazy_format_exception(exception)) 219 220 def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack: 221 """Records the current Python call stack.""" 222 frames_to_skip += 1 # Skip this function. 223 stack = utils.python_call_stack(frames_to_skip=frames_to_skip) 224 self.with_stack(stack) 225 if len(stack.frames) > 0: 226 self.with_location(stack.frames[0].location) 227 return stack 228 229 def record_python_call( 230 self, 231 fn: Callable, 232 state: Mapping[str, str], 233 message: str | None = None, 234 frames_to_skip: int = 0, 235 ) -> infra.ThreadFlowLocation: 236 """Records a python call as one thread flow step.""" 237 frames_to_skip += 1 # Skip this function. 238 stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5) 239 location = utils.function_location(fn) 240 location.message = message 241 # Add function location to the top of the stack. 242 stack.frames.insert(0, infra.StackFrame(location=location)) 243 thread_flow_location = infra.ThreadFlowLocation( 244 location=location, 245 state=state, 246 index=len(self.thread_flow_locations), 247 stack=stack, 248 ) 249 self.with_thread_flow_location(thread_flow_location) 250 return thread_flow_location 251 252 253class RuntimeErrorWithDiagnostic(RuntimeError): 254 """Runtime error with enclosed diagnostic information.""" 255 256 def __init__(self, diagnostic: Diagnostic): 257 super().__init__(diagnostic.message) 258 self.diagnostic = diagnostic 259 260 261@dataclasses.dataclass 262class DiagnosticContext(Generic[_Diagnostic]): 263 name: str 264 version: str 265 options: infra.DiagnosticOptions = dataclasses.field( 266 default_factory=infra.DiagnosticOptions 267 ) 268 diagnostics: list[_Diagnostic] = dataclasses.field(init=False, default_factory=list) 269 # TODO(bowbao): Implement this. 270 # _invocation: infra.Invocation = dataclasses.field(init=False) 271 _inflight_diagnostics: list[_Diagnostic] = dataclasses.field( 272 init=False, default_factory=list 273 ) 274 _previous_log_level: int = dataclasses.field(init=False, default=logging.WARNING) 275 logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger) 276 _bound_diagnostic_type: type = dataclasses.field(init=False, default=Diagnostic) 277 278 def __enter__(self): 279 self._previous_log_level = self.logger.level 280 self.logger.setLevel(self.options.verbosity_level) 281 return self 282 283 def __exit__(self, exc_type, exc_val, exc_tb): 284 self.logger.setLevel(self._previous_log_level) 285 return None 286 287 def sarif(self) -> sarif.Run: 288 """Returns the SARIF Run object.""" 289 unique_rules = {diagnostic.rule for diagnostic in self.diagnostics} 290 return sarif.Run( 291 sarif.Tool( 292 driver=sarif.ToolComponent( 293 name=self.name, 294 version=self.version, 295 rules=[rule.sarif() for rule in unique_rules], 296 ) 297 ), 298 results=[diagnostic.sarif() for diagnostic in self.diagnostics], 299 ) 300 301 def sarif_log(self) -> sarif.SarifLog: # type: ignore[name-defined] 302 """Returns the SARIF Log object.""" 303 return sarif.SarifLog( 304 version=sarif_version.SARIF_VERSION, 305 schema_uri=sarif_version.SARIF_SCHEMA_LINK, 306 runs=[self.sarif()], 307 ) 308 309 def to_json(self) -> str: 310 return formatter.sarif_to_json(self.sarif_log()) 311 312 def dump(self, file_path: str, compress: bool = False) -> None: 313 """Dumps the SARIF log to a file.""" 314 if compress: 315 with gzip.open(file_path, "wt") as f: 316 f.write(self.to_json()) 317 else: 318 with open(file_path, "w") as f: 319 f.write(self.to_json()) 320 321 def log(self, diagnostic: _Diagnostic) -> None: 322 """Logs a diagnostic. 323 324 This method should be used only after all the necessary information for the diagnostic 325 has been collected. 326 327 Args: 328 diagnostic: The diagnostic to add. 329 """ 330 if not isinstance(diagnostic, self._bound_diagnostic_type): 331 raise TypeError( 332 f"Expected diagnostic of type {self._bound_diagnostic_type}, got {type(diagnostic)}" 333 ) 334 if self.options.warnings_as_errors and diagnostic.level == infra.Level.WARNING: # type: ignore[attr-defined] 335 diagnostic.level = infra.Level.ERROR # type: ignore[attr-defined] 336 self.diagnostics.append(diagnostic) # type: ignore[arg-type] 337 338 def log_and_raise_if_error(self, diagnostic: _Diagnostic) -> None: 339 """Logs a diagnostic and raises an exception if it is an error. 340 341 Use this method for logging non inflight diagnostics where diagnostic level is not known or 342 lower than ERROR. If it is always expected raise, use `log` and explicit 343 `raise` instead. Otherwise there is no way to convey the message that it always 344 raises to Python intellisense and type checking tools. 345 346 This method should be used only after all the necessary information for the diagnostic 347 has been collected. 348 349 Args: 350 diagnostic: The diagnostic to add. 351 """ 352 self.log(diagnostic) 353 if diagnostic.level == infra.Level.ERROR: 354 if diagnostic.source_exception is not None: 355 raise diagnostic.source_exception 356 raise RuntimeErrorWithDiagnostic(diagnostic) 357 358 @contextlib.contextmanager 359 def add_inflight_diagnostic( 360 self, diagnostic: _Diagnostic 361 ) -> Generator[_Diagnostic, None, None]: 362 """Adds a diagnostic to the context. 363 364 Use this method to add diagnostics that are not created by the context. 365 Args: 366 diagnostic: The diagnostic to add. 367 """ 368 self._inflight_diagnostics.append(diagnostic) 369 try: 370 yield diagnostic 371 finally: 372 self._inflight_diagnostics.pop() 373 374 def push_inflight_diagnostic(self, diagnostic: _Diagnostic) -> None: 375 """Pushes a diagnostic to the inflight diagnostics stack. 376 377 Args: 378 diagnostic: The diagnostic to push. 379 380 Raises: 381 ValueError: If the rule is not supported by the tool. 382 """ 383 self._inflight_diagnostics.append(diagnostic) 384 385 def pop_inflight_diagnostic(self) -> _Diagnostic: 386 """Pops the last diagnostic from the inflight diagnostics stack. 387 388 Returns: 389 The popped diagnostic. 390 """ 391 return self._inflight_diagnostics.pop() 392 393 def inflight_diagnostic(self, rule: infra.Rule | None = None) -> _Diagnostic: 394 if rule is None: 395 # TODO(bowbao): Create builtin-rules and create diagnostic using that. 396 if len(self._inflight_diagnostics) <= 0: 397 raise AssertionError("No inflight diagnostics") 398 399 return self._inflight_diagnostics[-1] 400 else: 401 for diagnostic in reversed(self._inflight_diagnostics): 402 if diagnostic.rule == rule: 403 return diagnostic 404 raise AssertionError(f"No inflight diagnostic for rule {rule.name}") 405