xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/exporter/_reporting.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import dataclasses
5import re
6from typing import TYPE_CHECKING
7
8from torch.onnx._internal.exporter import _analysis, _registration, _verification
9
10
11if TYPE_CHECKING:
12    import os
13
14    from onnxscript import ir
15
16    import torch
17
18
@dataclasses.dataclass
class ExportStatus:
    """Per-stage outcome of an ONNX export attempt.

    Every field is tri-state: ``True`` means the stage succeeded, ``False``
    means it was attempted and failed, and ``None`` (the default) means no
    result was recorded for that stage (e.g. it was not run).
    """

    # Whether torch.export.export() succeeds
    torch_export: bool | None = None
    # Whether torch.export.export(..., strict=False) succeeds
    torch_export_non_strict: bool | None = None
    # Whether torch.jit.trace succeeds
    torch_jit: bool | None = None
    # Whether ONNX translation succeeds
    onnx_translation: bool | None = None
    # Whether ONNX model passes onnx.checker.check_model
    onnx_checker: bool | None = None
    # Whether ONNX model runs successfully with ONNX Runtime
    onnx_runtime: bool | None = None
    # Whether the output of the ONNX model is accurate
    output_accuracy: bool | None = None
35
36
37def _status_emoji(status: bool | None) -> str:
38    if status is None:
39        return "⚪"
40    return "✅" if status else "❌"
41
42
def _format_export_status(status: ExportStatus) -> str:
    """Render ``status`` as a Markdown code block with one checklist line per stage.

    Each line is prefixed by the emoji ``_status_emoji`` returns for that
    stage. The result ends with a blank line so more Markdown can follow it.
    """
    checklist = (
        (status.torch_export, "Obtain model graph with `torch.export.export`"),
        (
            status.torch_export_non_strict,
            "Obtain model graph with `torch.export.export(..., strict=False)`",
        ),
        (status.torch_jit, "Obtain model graph with `torch.jit.trace`"),
        (status.onnx_translation, "Translate the graph into ONNX"),
        (status.onnx_checker, "Run `onnx.checker` on the ONNX model"),
        (status.onnx_runtime, "Execute the model with ONNX Runtime"),
        (status.output_accuracy, "Validate model output accuracy"),
    )
    lines = "".join(
        f"{_status_emoji(outcome)} {label}\n" for outcome, label in checklist
    )
    return "```\n" + lines + "```\n\n"
55
56
57def _strip_color_from_string(text: str) -> str:
58    # This regular expression matches ANSI escape codes
59    # https://github.com/pytorch/pytorch/blob/9554a9af8788c57e1c5222c39076a5afcf0998ae/torch/_dynamo/utils.py#L2785-L2788
60    ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]")
61    return ansi_escape.sub("", text)
62
63
def _format_exported_program(exported_program: torch.export.ExportedProgram) -> str:
    """Render ``exported_program`` as a color-free Markdown ``python`` code block."""
    # Adapted from https://github.com/pytorch/pytorch/pull/128476 to remove
    # colors. graph_module.print_readable could be called directly, but its
    # colored option was added only recently and may be missing from the
    # user's PyTorch version, so str(ExportedProgram) is used and the ANSI
    # codes are stripped afterwards.
    plain = _strip_color_from_string(str(exported_program))
    return "```python\n" + plain + "\n```\n\n"
73
74
def construct_report_file_name(timestamp: str, status: ExportStatus) -> str:
    """Build the report file name, tagged with the most relevant outcome.

    The tag is chosen in priority order:

    - ``pt_export``: none of ``torch_export``, ``torch_export_non_strict``,
      ``torch_jit`` is truthy (each is ``False`` or ``None``).
    - ``conversion`` / ``checker`` / ``runtime`` / ``accuracy``: the first of
      ``onnx_translation``, ``onnx_checker``, ``onnx_runtime``,
      ``output_accuracy`` that is explicitly ``False``.
    - ``strategies``: ``torch_export`` or ``torch_export_non_strict`` is
      explicitly ``False`` even though some capture strategy succeeded.
    - ``success``: otherwise.

    Returns:
        ``"onnx_export_{timestamp}_{tag}.md"``.
    """
    if not any((status.torch_export, status.torch_export_non_strict, status.torch_jit)):
        postfix = "pt_export"
    else:
        # Stage results may be None (not recorded), so only an explicit
        # False counts as a failure.
        ordered_stages = (
            ("conversion", status.onnx_translation),
            ("checker", status.onnx_checker),
            ("runtime", status.onnx_runtime),
            ("accuracy", status.output_accuracy),
        )
        postfix = next(
            (tag for tag, outcome in ordered_stages if outcome is False), None
        )
        if postfix is None:
            some_strategy_failed = (
                status.torch_export is False
                or status.torch_export_non_strict is False
            )
            postfix = "strategies" if some_strategy_failed else "success"
    return f"onnx_export_{timestamp}_{postfix}.md"
94
95
def format_decomp_comparison(
    pre_decomp_unique_ops: set[str],
    post_decomp_unique_ops: set[str],
) -> str:
    """Format the decomposition comparison result.

    Both op collections are sorted, so the output does not depend on set
    iteration order.

    Args:
        pre_decomp_unique_ops: The ops that appear only in the ExportedProgram
            before decomposition.
        post_decomp_unique_ops: The ops that appear only in the ExportedProgram
            after decomposition.

    Returns:
        The formatted comparison result as two Markdown paragraphs.
    """
    return (
        f"Ops exist only in the ExportedProgram before decomposition: `{sorted(pre_decomp_unique_ops)}`\n\n"
        f"Ops exist only in the ExportedProgram after decomposition: `{sorted(post_decomp_unique_ops)}`\n"
    )
113
114
def format_verification_infos(
    verification_infos: list[_verification.VerificationInfo],
) -> str:
    """Render one Markdown line per verification entry.

    Each line shows the entry's name, its maximum absolute and relative
    differences (in ``e`` notation), and both difference histograms. Lines
    are separated by ``"\\n"``. An empty input yields an empty string.

    Args:
        verification_infos: The verification results to render.

    Returns:
        The formatted verification result.
    """
    rendered = []
    for entry in verification_infos:
        rendered.append(
            f"`{entry.name}`: `max_abs_diff={entry.max_abs_diff:e}`, "
            f"`max_rel_diff={entry.max_rel_diff:e}`, "
            f"`abs_diff_hist={entry.abs_diff_hist}`, `rel_diff_hist={entry.rel_diff_hist}`"
        )
    return "\n".join(rendered)
131
132
def create_torch_export_error_report(
    filename: str | os.PathLike,
    formatted_traceback: str,
    *,
    export_status: ExportStatus,
    profile_result: str | None,
) -> None:
    """Write a Markdown error report to ``filename``, overwriting any existing file.

    The report contains the export status checklist, the traceback in a
    ``pytb`` code block and, when given, the profiling output.

    Args:
        filename: Path of the Markdown file to write.
        formatted_traceback: Pre-formatted traceback text to embed.
        export_status: Per-stage status rendered as the checklist at the top.
        profile_result: Profiler output to append, or ``None`` to omit the
            profiling section.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write("# PyTorch ONNX Conversion Error Report\n\n")
        f.write(_format_export_status(export_status))
        f.write("Error message:\n\n")
        f.write("```pytb\n")
        f.write(formatted_traceback)
        # A closing fence must start on its own line; otherwise the code
        # block is never terminated and swallows the rest of the report.
        if not formatted_traceback.endswith("\n"):
            f.write("\n")
        f.write("```\n\n")
        if profile_result is not None:
            f.write("## Profiling result\n\n")
            f.write("```\n")
            f.write(profile_result)
            if not profile_result.endswith("\n"):
                f.write("\n")
            f.write("```\n")
152
153
def create_onnx_export_report(
    filename: str | os.PathLike,
    formatted_traceback: str,
    program: torch.export.ExportedProgram,
    *,
    decomp_comparison: str | None = None,
    export_status: ExportStatus,
    profile_result: str | None,
    model: ir.Model | None = None,
    registry: _registration.ONNXRegistry | None = None,
    verification_result: str | None = None,
) -> None:
    """Write a Markdown conversion report to ``filename``, overwriting any existing file.

    Sections, in order: the export status checklist, error messages, the
    exported program, the ONNX model (if ``model`` is given), the analysis,
    the decomposition comparison, verification results and profiling output.
    The last three appear only when the corresponding argument is not ``None``.

    Args:
        filename: Path of the Markdown file to write.
        formatted_traceback: Pre-formatted traceback text to embed.
        program: The ExportedProgram to print and analyze.
        decomp_comparison: Pre-formatted decomposition comparison, or ``None``.
        export_status: Per-stage status rendered as the checklist at the top.
        profile_result: Profiler output to append, or ``None``.
        model: The translated ONNX model to print, or ``None``.
        registry: Forwarded to ``_analysis.analyze``.
        verification_result: Pre-formatted verification results, or ``None``.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write("# PyTorch ONNX Conversion Report\n\n")
        f.write(_format_export_status(export_status))
        f.write("## Error messages\n\n")
        f.write("```pytb\n")
        f.write(formatted_traceback)
        f.write("\n```\n\n")
        f.write("## Exported program\n\n")
        f.write(_format_exported_program(program))
        if model is not None:
            f.write("## ONNX model\n\n")
            f.write("```python\n")
            f.write(str(model))
            f.write("\n```\n\n")
        f.write("## Analysis\n\n")
        # The open file is handed to the analyzer; presumably it writes its
        # findings into the report directly — confirm in _analysis.analyze.
        _analysis.analyze(program, file=f, registry=registry)
        if decomp_comparison is not None:
            f.write("\n## Decomposition comparison\n\n")
            f.write(decomp_comparison)
            f.write("\n")
        if verification_result is not None:
            f.write("\n## Verification results\n\n")
            f.write(verification_result)
            f.write("\n")
        if profile_result is not None:
            f.write("\n## Profiling result\n\n")
            f.write("```\n")
            f.write(profile_result)
            # A closing fence must start on its own line; otherwise the code
            # block is never terminated.
            if not profile_result.endswith("\n"):
                f.write("\n")
            f.write("```\n")
195