xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/diagnostics/infra/context.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""A diagnostic context based on SARIF."""
3
4from __future__ import annotations
5
6import contextlib
7import dataclasses
8import gzip
9import logging
10from typing import Callable, Generator, Generic, Literal, Mapping, TypeVar
11from typing_extensions import Self
12
13from torch.onnx._internal.diagnostics import infra
14from torch.onnx._internal.diagnostics.infra import formatter, sarif, utils
15from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
16
17
# Type variable bound to `Diagnostic`; `DiagnosticContext` below is generic over it.
# This is a workaround for mypy not supporting Self from typing_extensions.
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic")
# Module-level logger used as the default `logger` of both `Diagnostic` and
# `DiagnosticContext`.
diagnostic_logger: logging.Logger = logging.getLogger(__name__)
21
22
@dataclasses.dataclass
class Diagnostic:
    """A single diagnostic that can be rendered as a SARIF ``Result``.

    A diagnostic holds the rule it reports on, a level, an optional message and
    lists of locations, stacks, graphs and thread flow locations. Messages
    emitted through `log` (and the `debug`/`info`/`warning`/`error` helpers) are
    sent to `logger` and also accumulated in `additional_messages`, which
    `sarif` appends to the markdown form of the message.
    """

    rule: infra.Rule
    level: infra.Level
    message: str | None = None
    locations: list[infra.Location] = dataclasses.field(default_factory=list)
    stacks: list[infra.Stack] = dataclasses.field(default_factory=list)
    graphs: list[infra.Graph] = dataclasses.field(default_factory=list)
    thread_flow_locations: list[infra.ThreadFlowLocation] = dataclasses.field(
        default_factory=list
    )
    additional_messages: list[str] = dataclasses.field(default_factory=list)
    tags: list[infra.Tag] = dataclasses.field(default_factory=list)
    source_exception: Exception | None = None
    """The exception that caused this diagnostic to be created."""
    logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
    """The logger for this diagnostic. Defaults to 'diagnostic_logger' which has the same
    log level setting with `DiagnosticOptions.verbosity_level`."""
    # Nesting depth of the currently open `log_section` blocks; controls how many
    # '#' characters prefix a section title.
    _current_log_section_depth: int = 0

    def __post_init__(self) -> None:
        # Intentionally empty hook, kept so the method exists on the base class.
        pass

    def sarif(self) -> sarif.Result:
        """Returns the SARIF Result representation of this diagnostic."""
        # Fall back to the rule's default template when no explicit message is set.
        message = self.message or self.rule.message_default_template
        if self.additional_messages:
            additional_message = "\n".join(self.additional_messages)
            message_markdown = (
                f"{message}\n\n## Additional Message:\n\n{additional_message}"
            )
        else:
            message_markdown = message

        # Only `Level.NONE` is reported as "informational"; every other level is "fail".
        kind: Literal["informational", "fail"] = (
            "informational" if self.level == infra.Level.NONE else "fail"
        )

        sarif_result = sarif.Result(
            message=sarif.Message(text=message, markdown=message_markdown),
            level=self.level.name.lower(),  # type: ignore[arg-type]
            rule_id=self.rule.id,
            kind=kind,
        )
        sarif_result.locations = [location.sarif() for location in self.locations]
        sarif_result.stacks = [stack.sarif() for stack in self.stacks]
        sarif_result.graphs = [graph.sarif() for graph in self.graphs]
        # All thread flow locations are emitted as one thread flow in one code flow.
        sarif_result.code_flows = [
            sarif.CodeFlow(
                thread_flows=[
                    sarif.ThreadFlow(
                        locations=[loc.sarif() for loc in self.thread_flow_locations]
                    )
                ]
            )
        ]
        sarif_result.properties = sarif.PropertyBag(
            tags=[tag.value for tag in self.tags]
        )
        return sarif_result

    def with_location(self: Self, location: infra.Location) -> Self:
        """Adds a location to the diagnostic and returns the diagnostic itself."""
        self.locations.append(location)
        return self

    def with_thread_flow_location(
        self: Self, location: infra.ThreadFlowLocation
    ) -> Self:
        """Adds a thread flow location to the diagnostic and returns the diagnostic itself."""
        self.thread_flow_locations.append(location)
        return self

    def with_stack(self: Self, stack: infra.Stack) -> Self:
        """Adds a stack to the diagnostic and returns the diagnostic itself."""
        self.stacks.append(stack)
        return self

    def with_graph(self: Self, graph: infra.Graph) -> Self:
        """Adds a graph to the diagnostic and returns the diagnostic itself."""
        self.graphs.append(graph)
        return self

    @contextlib.contextmanager
    def log_section(
        self, level: int, message: str, *args, **kwargs
    ) -> Generator[None, None, None]:
        """
        Context manager for a section of log messages, denoted by a title message and increased indentation.

        Same api as `logging.Logger.log`.

        This context manager logs the given title at the specified log level, increases the current
        section depth for subsequent log messages, and ensures that the section depth is decreased
        again when exiting the context.

        Args:
            level: The log level.
            message: The title message to log.
            *args: The arguments to the message. Use `LazyString` to defer the
                expensive evaluation of the arguments until the message is actually logged.
            **kwargs: The keyword arguments for `logging.Logger.log`.

        Yields:
            None: This context manager does not yield any value.

        Example:
            >>> with DiagnosticContext("DummyContext", "1.0"):
            ...     rule = infra.Rule("RuleID", "DummyRule", "Rule message")
            ...     diagnostic = Diagnostic(rule, infra.Level.WARNING)
            ...     with diagnostic.log_section(logging.INFO, "My Section"):
            ...         diagnostic.log(logging.INFO, "My Message")
            ...         with diagnostic.log_section(logging.INFO, "My Subsection"):
            ...             diagnostic.log(logging.INFO, "My Submessage")
            ...     diagnostic.additional_messages
            ['## My Section', 'My Message', '### My Subsection', 'My Submessage']
        """
        if self.logger.isEnabledFor(level):
            # Titles start at "##" and gain one extra "#" per nesting level.
            indented_format_message = (
                f"##{'#' * self._current_log_section_depth} {message}"
            )
            self.log(
                level,
                indented_format_message,
                *args,
                **kwargs,
            )
        # The depth changes even when the title was filtered out, so nested
        # sections keep a consistent heading level.
        self._current_log_section_depth += 1
        try:
            yield
        finally:
            self._current_log_section_depth -= 1

    def log(self, level: int, message: str, *args, **kwargs) -> None:
        """Logs a message within the diagnostic. Same api as `logging.Logger.log`.

        If logger is not enabled for the given level, the message will not be logged.
        Otherwise, the message will be logged and also added to the diagnostic's additional_messages.

        When no ``args`` are given and ``message`` is not a valid %-format string on
        its own (for example it contains a literal ``%``), the message is recorded
        verbatim instead of raising.

        The default setting for `DiagnosticOptions.verbosity_level` is `logging.INFO`. Based on this default,
        the log level recommendations are as follows. If you've set a different default verbosity level in your
        application, please adjust accordingly:

        - logging.ERROR: Log any events leading to application failure.
        - logging.WARNING: Log events that might result in application issues or failures, although not guaranteed.
        - logging.INFO: Log general useful information, ensuring minimal performance overhead.
        - logging.DEBUG: Log detailed debug information, which might affect performance when logged.

        Args:
            level: The log level.
            message: The message to log.
            *args: The arguments to the message. Use `LazyString` to defer the
                expensive evaluation of the arguments until the message is actually logged.
            **kwargs: The keyword arguments for `logging.Logger.log`.
        """
        if not self.logger.isEnabledFor(level):
            return

        if args:
            formatted_message = message % args
        else:
            # Without args, `message % ()` raises on a stray "%" (e.g. "50% done").
            # Still attempt it so "%%" keeps collapsing to "%" as before, but fall
            # back to the raw text rather than failing the caller.
            try:
                formatted_message = message % ()
            except (TypeError, ValueError):
                formatted_message = message

        # The message is already interpolated, so no args are forwarded to the logger.
        self.logger.log(level, formatted_message, **kwargs)
        self.additional_messages.append(formatted_message)

    def debug(self, message: str, *args, **kwargs) -> None:
        """Logs a debug message within the diagnostic. Same api as logging.Logger.debug.

        Checkout `log` for more details.
        """
        self.log(logging.DEBUG, message, *args, **kwargs)

    def info(self, message: str, *args, **kwargs) -> None:
        """Logs an info message within the diagnostic. Same api as logging.Logger.info.

        Checkout `log` for more details.
        """
        self.log(logging.INFO, message, *args, **kwargs)

    def warning(self, message: str, *args, **kwargs) -> None:
        """Logs a warning message within the diagnostic. Same api as logging.Logger.warning.

        Checkout `log` for more details.
        """
        self.log(logging.WARNING, message, *args, **kwargs)

    def error(self, message: str, *args, **kwargs) -> None:
        """Logs an error message within the diagnostic. Same api as logging.Logger.error.

        Checkout `log` for more details.
        """
        self.log(logging.ERROR, message, *args, **kwargs)

    def log_source_exception(self, level: int, exception: Exception) -> None:
        """Logs a source exception within the diagnostic.

        Stores the exception in `source_exception`, then invokes `log_section` and
        `log` to log the exception in markdown section format.

        Args:
            level: The log level.
            exception: The exception to record and log.
        """
        self.source_exception = exception
        with self.log_section(level, "Exception log"):
            self.log(level, "%s", formatter.lazy_format_exception(exception))

    def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack:
        """Records the current Python call stack.

        The captured stack is added to `stacks`; if it has at least one frame, the
        location of its first frame is also added to `locations`.

        Args:
            frames_to_skip: Number of frames to skip, not counting this method.

        Returns:
            The captured stack.
        """
        frames_to_skip += 1  # Skip this function.
        stack = utils.python_call_stack(frames_to_skip=frames_to_skip)
        self.with_stack(stack)
        if stack.frames:
            self.with_location(stack.frames[0].location)
        return stack

    def record_python_call(
        self,
        fn: Callable,
        state: Mapping[str, str],
        message: str | None = None,
        frames_to_skip: int = 0,
    ) -> infra.ThreadFlowLocation:
        """Records a python call as one thread flow step.

        Args:
            fn: The function whose location is recorded for this step.
            state: The state mapping stored on the thread flow location.
            message: Optional message attached to the function location.
            frames_to_skip: Number of frames to skip, not counting this method.

        Returns:
            The thread flow location that was appended to `thread_flow_locations`.
        """
        frames_to_skip += 1  # Skip this function.
        stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5)
        location = utils.function_location(fn)
        location.message = message
        # Add function location to the top of the stack.
        stack.frames.insert(0, infra.StackFrame(location=location))
        thread_flow_location = infra.ThreadFlowLocation(
            location=location,
            state=state,
            # Index is the position this step takes in `thread_flow_locations`.
            index=len(self.thread_flow_locations),
            stack=stack,
        )
        self.with_thread_flow_location(thread_flow_location)
        return thread_flow_location
251
252
class RuntimeErrorWithDiagnostic(RuntimeError):
    """A ``RuntimeError`` that carries the diagnostic it was raised for.

    The diagnostic's ``message`` becomes the exception's argument, and the
    diagnostic object itself is available as the ``diagnostic`` attribute.
    """

    def __init__(self, diagnostic: Diagnostic):
        # Keep the diagnostic on the exception so handlers can inspect it.
        self.diagnostic = diagnostic
        error_text = diagnostic.message
        super().__init__(error_text)
259
260
@dataclasses.dataclass
class DiagnosticContext(Generic[_Diagnostic]):
    """A collection of diagnostics that can be exported as a SARIF run/log.

    The context stores logged diagnostics in `diagnostics` and keeps a stack of
    "inflight" diagnostics. When used as a context manager it sets `logger` to
    `options.verbosity_level` on entry and restores the previous level on exit.
    """

    name: str
    version: str
    options: infra.DiagnosticOptions = dataclasses.field(
        default_factory=infra.DiagnosticOptions
    )
    diagnostics: list[_Diagnostic] = dataclasses.field(init=False, default_factory=list)
    # TODO(bowbao): Implement this.
    # _invocation: infra.Invocation = dataclasses.field(init=False)
    # Stack of diagnostics currently being built; the last element is the innermost.
    _inflight_diagnostics: list[_Diagnostic] = dataclasses.field(
        init=False, default_factory=list
    )
    # Logger level captured in `__enter__` and restored in `__exit__`.
    _previous_log_level: int = dataclasses.field(init=False, default=logging.WARNING)
    logger: logging.Logger = dataclasses.field(init=False, default=diagnostic_logger)
    # Type that `log` accepts; anything else is rejected with a TypeError.
    _bound_diagnostic_type: type = dataclasses.field(init=False, default=Diagnostic)

    def __enter__(self):
        """Applies `options.verbosity_level` to the logger and returns the context."""
        self._previous_log_level = self.logger.level
        self.logger.setLevel(self.options.verbosity_level)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restores the logger level saved in `__enter__`; exceptions are not suppressed."""
        self.logger.setLevel(self._previous_log_level)
        return None

    def sarif(self) -> sarif.Run:
        """Returns the SARIF Run object."""
        # `dict.fromkeys` de-duplicates while keeping first-seen order, so the
        # emitted rule list follows the order of `diagnostics`; a set gives no
        # ordering guarantee.
        unique_rules = dict.fromkeys(
            diagnostic.rule for diagnostic in self.diagnostics
        )
        return sarif.Run(
            sarif.Tool(
                driver=sarif.ToolComponent(
                    name=self.name,
                    version=self.version,
                    rules=[rule.sarif() for rule in unique_rules],
                )
            ),
            results=[diagnostic.sarif() for diagnostic in self.diagnostics],
        )

    def sarif_log(self) -> sarif.SarifLog:  # type: ignore[name-defined]
        """Returns the SARIF Log object."""
        return sarif.SarifLog(
            version=sarif_version.SARIF_VERSION,
            schema_uri=sarif_version.SARIF_SCHEMA_LINK,
            runs=[self.sarif()],
        )

    def to_json(self) -> str:
        """Returns the SARIF log serialized by `formatter.sarif_to_json`."""
        return formatter.sarif_to_json(self.sarif_log())

    def dump(self, file_path: str, compress: bool = False) -> None:
        """Dumps the SARIF log to a file.

        Args:
            file_path: Path of the file to write.
            compress: If True, the file is written gzip-compressed.
        """
        # Use an explicit encoding so the output does not depend on the
        # platform's default locale encoding.
        if compress:
            with gzip.open(file_path, "wt", encoding="utf-8") as f:
                f.write(self.to_json())
        else:
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(self.to_json())

    def log(self, diagnostic: _Diagnostic) -> None:
        """Logs a diagnostic.

        This method should be used only after all the necessary information for the diagnostic
        has been collected.

        Args:
            diagnostic: The diagnostic to add.

        Raises:
            TypeError: If `diagnostic` is not an instance of the bound diagnostic type.
        """
        if not isinstance(diagnostic, self._bound_diagnostic_type):
            raise TypeError(
                f"Expected diagnostic of type {self._bound_diagnostic_type}, got {type(diagnostic)}"
            )
        # With `warnings_as_errors`, a WARNING is promoted to ERROR before storing.
        if self.options.warnings_as_errors and diagnostic.level == infra.Level.WARNING:  # type: ignore[attr-defined]
            diagnostic.level = infra.Level.ERROR  # type: ignore[attr-defined]
        self.diagnostics.append(diagnostic)  # type: ignore[arg-type]

    def log_and_raise_if_error(self, diagnostic: _Diagnostic) -> None:
        """Logs a diagnostic and raises an exception if it is an error.

        Use this method for logging non inflight diagnostics where diagnostic level is not known or
        lower than ERROR. If it is always expected raise, use `log` and explicit
        `raise` instead. Otherwise there is no way to convey the message that it always
        raises to Python intellisense and type checking tools.

        This method should be used only after all the necessary information for the diagnostic
        has been collected.

        Args:
            diagnostic: The diagnostic to add.

        Raises:
            Exception: The diagnostic's `source_exception`, if the level is ERROR and
                one is set.
            RuntimeErrorWithDiagnostic: If the level is ERROR and there is no
                `source_exception`.
        """
        self.log(diagnostic)
        if diagnostic.level == infra.Level.ERROR:
            # Prefer re-raising the original exception when there is one.
            if diagnostic.source_exception is not None:
                raise diagnostic.source_exception
            raise RuntimeErrorWithDiagnostic(diagnostic)

    @contextlib.contextmanager
    def add_inflight_diagnostic(
        self, diagnostic: _Diagnostic
    ) -> Generator[_Diagnostic, None, None]:
        """Adds a diagnostic to the context.

        Use this method to add diagnostics that are not created by the context.
        The diagnostic is pushed onto the inflight stack for the duration of the
        ``with`` block and popped on exit, including when the block raises.

        Args:
            diagnostic: The diagnostic to add.

        Yields:
            The diagnostic that was passed in.
        """
        self._inflight_diagnostics.append(diagnostic)
        try:
            yield diagnostic
        finally:
            self._inflight_diagnostics.pop()

    def push_inflight_diagnostic(self, diagnostic: _Diagnostic) -> None:
        """Pushes a diagnostic to the inflight diagnostics stack.

        No validation is performed on the diagnostic or its rule.

        Args:
            diagnostic: The diagnostic to push.
        """
        self._inflight_diagnostics.append(diagnostic)

    def pop_inflight_diagnostic(self) -> _Diagnostic:
        """Pops the last diagnostic from the inflight diagnostics stack.

        Returns:
            The popped diagnostic.

        Raises:
            IndexError: If the inflight diagnostics stack is empty.
        """
        return self._inflight_diagnostics.pop()

    def inflight_diagnostic(self, rule: infra.Rule | None = None) -> _Diagnostic:
        """Returns the most recently pushed inflight diagnostic.

        Args:
            rule: If given, return the most recently pushed inflight diagnostic
                whose rule equals `rule` instead of the top of the stack.

        Returns:
            The matching inflight diagnostic.

        Raises:
            AssertionError: If there is no inflight diagnostic, or none for `rule`.
        """
        if rule is None:
            # TODO(bowbao): Create builtin-rules and create diagnostic using that.
            if not self._inflight_diagnostics:
                raise AssertionError("No inflight diagnostics")

            return self._inflight_diagnostics[-1]

        # Search from the top of the stack so the innermost match wins.
        for diagnostic in reversed(self._inflight_diagnostics):
            if diagnostic.rule == rule:
                return diagnostic
        raise AssertionError(f"No inflight diagnostic for rule {rule.name}")
405