xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/diagnostics/infra/_infra.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""This file defines an additional layer of abstraction on top of the SARIF OM."""
3
4from __future__ import annotations
5
6import dataclasses
7import enum
8import logging
9from typing import Mapping, Sequence
10
11from torch.onnx._internal.diagnostics.infra import formatter, sarif
12
13
class Level(enum.IntEnum):
    """Severity of a diagnostic, following the SARIF specification.

    The set of levels is fixed by SARIF and must not be extended; use
    ``infra.Tag`` for custom categorisation instead. Guidelines for choosing
    a level:

    - NONE: Informational result that does not indicate the presence of a problem.
    - NOTE: An opportunity for improvement was found.
    - WARNING: A potential problem was found.
    - ERROR: A serious problem was found.

    Being an ``enum.IntEnum``, every level is also an ``int`` whose value equals
    the corresponding level of Python's ``logging`` module:

        Level.NONE = logging.DEBUG = 10
        Level.NOTE = logging.INFO = 20
        Level.WARNING = logging.WARNING = 30
        Level.ERROR = logging.ERROR = 40
    """

    # Values come straight from the logging module so the two numbering
    # schemes cannot drift apart.
    NONE = logging.DEBUG
    NOTE = logging.INFO
    WARNING = logging.WARNING
    ERROR = logging.ERROR


# Lowercase alias for ``Level``.
levels = Level
44
45
class Tag(enum.Enum):
    """Base class for diagnostic tags.

    It declares no members of its own; subclass it to define custom tags.
    """
48
49
class PatchedPropertyBag(sarif.PropertyBag):
    """SARIF property bag that also accepts arbitrary extra properties.

    The SARIF specification describes a property bag as "an object (section
    3.6) containing an unordered set of properties with arbitrary names", but
    that freedom is not reflected in the JSON schema and therefore not in the
    generated Python class. This subclass accepts any extra keyword arguments
    and stores each one as an instance attribute.
    """

    def __init__(self, tags: list[str] | None = None, **kwargs):
        super().__init__(tags=tags)
        # Every extra keyword becomes an instance attribute of the same name.
        vars(self).update(kwargs)
63
64
@dataclasses.dataclass(frozen=True)
class Rule:
    """An immutable diagnostic rule.

    A rule carries an identity (``id`` and ``name``), a default message template
    used to build diagnostic messages, and optional human-readable descriptions.
    It can be built from, and converted to, a SARIF reporting descriptor.
    """

    id: str
    name: str
    message_default_template: str
    short_description: str | None = None
    full_description: str | None = None
    full_description_markdown: str | None = None
    help_uri: str | None = None

    @classmethod
    def from_sarif(cls, **kwargs):
        """Returns a rule from the SARIF reporting descriptor.

        Args:
            **kwargs: The reporting descriptor as snake_case keyword arguments.
                Required: ``id``, ``name`` and ``message_strings``; the entry
                ``message_strings["default"]["text"]`` becomes the message
                template.
                Optional: ``short_description`` (a mapping with a ``text`` key),
                ``full_description`` (a mapping with ``text`` and ``markdown``
                keys) and ``help_uri``. Optional entries may be omitted or set
                to ``None``.

        Returns:
            A new instance of ``cls``.

        Raises:
            KeyError: If a required entry is missing.
        """
        # `or {}` treats an explicit None the same as a missing key. A plain
        # `.get(key, {})` returns None in that case and `.get("text")` then
        # raises AttributeError.
        short_description = kwargs.get("short_description") or {}
        full_description = kwargs.get("full_description") or {}

        return cls(
            id=kwargs["id"],
            name=kwargs["name"],
            message_default_template=kwargs["message_strings"]["default"]["text"],
            short_description=short_description.get("text"),
            full_description=full_description.get("text"),
            full_description_markdown=full_description.get("markdown"),
            help_uri=kwargs.get("help_uri"),
        )

    def sarif(self) -> sarif.ReportingDescriptor:
        """Returns a SARIF reporting descriptor of this Rule.

        Descriptions are only emitted when their ``text`` is set. The markdown
        variant of the full description is only emitted together with its text.
        """
        short_description = (
            sarif.MultiformatMessageString(text=self.short_description)
            if self.short_description is not None
            else None
        )
        full_description = (
            sarif.MultiformatMessageString(
                text=self.full_description, markdown=self.full_description_markdown
            )
            if self.full_description is not None
            else None
        )
        return sarif.ReportingDescriptor(
            id=self.id,
            name=self.name,
            short_description=short_description,
            full_description=full_description,
            help_uri=self.help_uri,
        )

    def format(self, level: Level, *args, **kwargs) -> tuple[Rule, Level, str]:
        """Returns a tuple of (rule, level, message) for a diagnostic.

        The message is produced by `format_message` from this rule's default
        template and the given `*args` / `**kwargs`. `level` is not interpreted
        here; it is passed through unchanged so the caller chooses the
        diagnostic's severity.
        """
        return (self, level, self.format_message(*args, **kwargs))

    def format_message(self, *args, **kwargs) -> str:
        """Returns the formatted default message of this Rule.

        The default template is filled in with `str.format(*args, **kwargs)`.
        This method should be overridden (with code generation) by subclasses to
        reflect the exact arguments needed by the message template.
        """
        return self.message_default_template.format(*args, **kwargs)
134
135
@dataclasses.dataclass
class Location:
    """A source location attached to a diagnostic.

    Every field is optional. ``function`` is recorded on the object but is not
    part of the SARIF conversion below.
    """

    uri: str | None = None
    line: int | None = None
    message: str | None = None
    start_column: int | None = None
    end_column: int | None = None
    snippet: str | None = None
    function: str | None = None

    def sarif(self) -> sarif.Location:
        """Builds the SARIF ``Location`` describing this source position.

        A region and a snippet content object are always emitted, even when
        their values are ``None``. The message is only emitted when set.
        """
        snippet_content = sarif.ArtifactContent(text=self.snippet)
        code_region = sarif.Region(
            start_line=self.line,
            start_column=self.start_column,
            end_column=self.end_column,
            snippet=snippet_content,
        )
        physical = sarif.PhysicalLocation(
            artifact_location=sarif.ArtifactLocation(uri=self.uri),
            region=code_region,
        )

        if self.message is None:
            location_message = None
        else:
            location_message = sarif.Message(text=self.message)

        return sarif.Location(physical_location=physical, message=location_message)
161        )
162
163
@dataclasses.dataclass
class StackFrame:
    """One frame of a stack trace, identified solely by its location."""

    location: Location

    def sarif(self) -> sarif.StackFrame:
        """Converts this frame into a SARIF stack frame."""
        frame_location = self.location.sarif()
        return sarif.StackFrame(location=frame_location)
171
172
@dataclasses.dataclass
class Stack:
    """A recorded stack trace; ``frames`` runs from the newest to the oldest frame."""

    frames: list[StackFrame] = dataclasses.field(default_factory=list)
    message: str | None = None

    def sarif(self) -> sarif.Stack:
        """Converts this stack trace into a SARIF stack, keeping frame order."""
        converted_frames = [frame_entry.sarif() for frame_entry in self.frames]

        if self.message is None:
            stack_message = None
        else:
            stack_message = sarif.Message(text=self.message)

        return sarif.Stack(frames=converted_frames, message=stack_message)
187        )
188
189
@dataclasses.dataclass
class ThreadFlowLocation:
    """A code location within a thread flow, with its recorded state.

    ``index`` is stored on the object but is not forwarded by ``sarif()``.
    """

    location: Location
    state: Mapping[str, str]
    index: int
    stack: Stack | None = None

    def sarif(self) -> sarif.ThreadFlowLocation:
        """Converts this entry into a SARIF thread flow location."""
        converted_location = self.location.sarif()
        converted_stack = None if self.stack is None else self.stack.sarif()
        return sarif.ThreadFlowLocation(
            location=converted_location,
            state=self.state,
            stack=converted_stack,
        )
205        )
206
207
@dataclasses.dataclass
class Graph:
    """A named, textual representation of a model graph attached to a diagnostic.

    Only the string form of the graph is stored; no node or edge structure is
    kept.
    """

    graph: str
    name: str
    description: str | None = None

    def sarif(self) -> sarif.Graph:
        """Converts this graph into a SARIF graph.

        The graph text becomes the SARIF ``description`` message. ``name`` and
        ``description`` are stored as extra entries of the property bag.
        """
        graph_text = sarif.Message(text=self.graph)
        extra_properties = PatchedPropertyBag(
            name=self.name, description=self.description
        )
        return sarif.Graph(description=graph_text, properties=extra_properties)
225        )
226
227
@dataclasses.dataclass
class RuleCollection:
    """A dataclass whose fields default to ``Rule`` instances.

    Membership (``rule in collection``) is decided by the ``(id, name)`` pair of
    the rules declared as field defaults.
    """

    _rule_id_name_set: frozenset[tuple[str, str]] = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        # Collect (id, name) of every field whose default value is a Rule.
        id_name_pairs = set()
        for rule_field in dataclasses.fields(self):
            default_value = rule_field.default
            if isinstance(default_value, Rule):
                id_name_pairs.add((default_value.id, default_value.name))
        self._rule_id_name_set = frozenset(id_name_pairs)

    def __contains__(self, rule: Rule) -> bool:
        """Returns whether a rule with the same id and name is in the collection."""
        lookup_key = (rule.id, rule.name)
        return lookup_key in self._rule_id_name_set

    @classmethod
    def custom_collection_from_list(
        cls, new_collection_class_name: str, rules: Sequence[Rule]
    ) -> RuleCollection:
        """Builds a RuleCollection subclass holding ``rules`` and returns an instance.

        Each rule becomes a field named after the snake_case form of its
        kebab-case ``name``, with the rule itself as the default value.
        """
        field_specs = [
            (
                formatter.kebab_case_to_snake_case(each_rule.name),
                type(each_rule),
                dataclasses.field(default=each_rule),
            )
            for each_rule in rules
        ]
        collection_class = dataclasses.make_dataclass(
            new_collection_class_name,
            field_specs,
            bases=(cls,),
        )
        return collection_class()
261        )()
262
263
class Invocation:
    """Placeholder for tracking top-level call arguments and diagnostic options.

    Not implemented yet: constructing an instance always raises
    ``NotImplementedError``.
    """

    # TODO: Implement this.
    def __init__(self) -> None:
        raise NotImplementedError
268        raise NotImplementedError
269
270
@dataclasses.dataclass
class DiagnosticOptions:
    """Options for diagnostic context.

    Attributes:
        verbosity_level: Amount of information logged for each diagnostic, on
            the same scale as the 'level' of Python's logging module.
        warnings_as_errors: When True, warning diagnostics are treated as error
            diagnostics.
    """

    # Amount of information logged per diagnostic; logging-module level scale.
    verbosity_level: int = logging.INFO

    # When True, warning diagnostics are treated as error diagnostics.
    warnings_as_errors: bool = False
285    """If True, warning diagnostics are treated as error diagnostics."""
286