1# mypy: allow-untyped-defs 2from __future__ import annotations 3 4import dataclasses 5import re 6from typing import TYPE_CHECKING 7 8from torch.onnx._internal.exporter import _analysis, _registration, _verification 9 10 11if TYPE_CHECKING: 12 import os 13 14 from onnxscript import ir 15 16 import torch 17 18 19@dataclasses.dataclass 20class ExportStatus: 21 # Whether torch.export.export.export() succeeds 22 torch_export: bool | None = None 23 # Whether torch.export.export.export(..., strict=False) succeeds 24 torch_export_non_strict: bool | None = None 25 # Whether torch.jit.trace succeeds 26 torch_jit: bool | None = None 27 # Whether ONNX translation succeeds 28 onnx_translation: bool | None = None 29 # Whether ONNX model passes onnx.checker.check_model 30 onnx_checker: bool | None = None 31 # Whether ONNX model runs successfully with ONNX Runtime 32 onnx_runtime: bool | None = None 33 # Whether the output of the ONNX model is accurate 34 output_accuracy: bool | None = None 35 36 37def _status_emoji(status: bool | None) -> str: 38 if status is None: 39 return "⚪" 40 return "✅" if status else "❌" 41 42 43def _format_export_status(status: ExportStatus) -> str: 44 return ( 45 f"```\n" 46 f"{_status_emoji(status.torch_export)} Obtain model graph with `torch.export.export`\n" 47 f"{_status_emoji(status.torch_export_non_strict)} Obtain model graph with `torch.export.export(..., strict=False)`\n" 48 f"{_status_emoji(status.torch_jit)} Obtain model graph with `torch.jit.trace`\n" 49 f"{_status_emoji(status.onnx_translation)} Translate the graph into ONNX\n" 50 f"{_status_emoji(status.onnx_checker)} Run `onnx.checker` on the ONNX model\n" 51 f"{_status_emoji(status.onnx_runtime)} Execute the model with ONNX Runtime\n" 52 f"{_status_emoji(status.output_accuracy)} Validate model output accuracy\n" 53 f"```\n\n" 54 ) 55 56 57def _strip_color_from_string(text: str) -> str: 58 # This regular expression matches ANSI escape codes 59 # https://github.com/pytorch/pytorch/blob/9554a9af8788c57e1c5222c39076a5afcf0998ae/torch/_dynamo/utils.py#L2785-L2788 60 ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") 61 return ansi_escape.sub("", text) 62 63 64def _format_exported_program(exported_program: torch.export.ExportedProgram) -> str: 65 # Adapted from https://github.com/pytorch/pytorch/pull/128476 66 # to remove colors 67 # Even though we can call graph_module.print_readable directly, since the 68 # colored option was added only recently, we can't guarantee that the 69 # version of PyTorch used by the user has this option. Therefore, we 70 # still call str(ExportedProgram) 71 text = f"```python\n{_strip_color_from_string(str(exported_program))}\n```\n\n" 72 return text 73 74 75def construct_report_file_name(timestamp: str, status: ExportStatus) -> str: 76 # Status could be None. So we need to check for False explicitly. 77 if not (status.torch_export or status.torch_export_non_strict or status.torch_jit): 78 # All strategies failed 79 postfix = "pt_export" 80 elif status.onnx_translation is False: 81 postfix = "conversion" 82 elif status.onnx_checker is False: 83 postfix = "checker" 84 elif status.onnx_runtime is False: 85 postfix = "runtime" 86 elif status.output_accuracy is False: 87 postfix = "accuracy" 88 elif status.torch_export is False or status.torch_export_non_strict is False: 89 # Some strategies failed 90 postfix = "strategies" 91 else: 92 postfix = "success" 93 return f"onnx_export_{timestamp}_{postfix}.md" 94 95 96def format_decomp_comparison( 97 pre_decomp_unique_ops: set[str], 98 post_decomp_unique_ops: set[str], 99) -> str: 100 """Format the decomposition comparison result. 101 102 Args: 103 unique_ops_in_a: The unique ops in the first program. 104 unique_ops_in_b: The unique ops in the second program. 105 106 Returns: 107 The formatted comparison result. 108 """ 109 return ( 110 f"Ops exist only in the ExportedProgram before decomposition: `{sorted(pre_decomp_unique_ops)}`\n\n" 111 f"Ops exist only in the ExportedProgram after decomposition: `{sorted(post_decomp_unique_ops)}`\n" 112 ) 113 114 115def format_verification_infos( 116 verification_infos: list[_verification.VerificationInfo], 117) -> str: 118 """Format the verification result. 119 120 Args: 121 verification_infos: The verification result. 122 123 Returns: 124 The formatted verification result. 125 """ 126 return "\n".join( 127 f"`{info.name}`: `max_abs_diff={info.max_abs_diff:e}`, `max_rel_diff={info.max_rel_diff:e}`, " 128 f"`abs_diff_hist={info.abs_diff_hist}`, `rel_diff_hist={info.rel_diff_hist}`" 129 for info in verification_infos 130 ) 131 132 133def create_torch_export_error_report( 134 filename: str | os.PathLike, 135 formatted_traceback: str, 136 *, 137 export_status: ExportStatus, 138 profile_result: str | None, 139): 140 with open(filename, "w", encoding="utf-8") as f: 141 f.write("# PyTorch ONNX Conversion Error Report\n\n") 142 f.write(_format_export_status(export_status)) 143 f.write("Error message:\n\n") 144 f.write("```pytb\n") 145 f.write(formatted_traceback) 146 f.write("```\n\n") 147 if profile_result is not None: 148 f.write("## Profiling result\n\n") 149 f.write("```\n") 150 f.write(profile_result) 151 f.write("```\n") 152 153 154def create_onnx_export_report( 155 filename: str | os.PathLike, 156 formatted_traceback: str, 157 program: torch.export.ExportedProgram, 158 *, 159 decomp_comparison: str | None = None, 160 export_status: ExportStatus, 161 profile_result: str | None, 162 model: ir.Model | None = None, 163 registry: _registration.ONNXRegistry | None = None, 164 verification_result: str | None = None, 165): 166 with open(filename, "w", encoding="utf-8") as f: 167 f.write("# PyTorch ONNX Conversion Report\n\n") 168 f.write(_format_export_status(export_status)) 169 f.write("## Error messages\n\n") 170 f.write("```pytb\n") 171 f.write(formatted_traceback) 172 f.write("\n```\n\n") 173 f.write("## Exported program\n\n") 174 f.write(_format_exported_program(program)) 175 if model is not None: 176 f.write("## ONNX model\n\n") 177 f.write("```python\n") 178 f.write(str(model)) 179 f.write("\n```\n\n") 180 f.write("## Analysis\n\n") 181 _analysis.analyze(program, file=f, registry=registry) 182 if decomp_comparison is not None: 183 f.write("\n## Decomposition comparison\n\n") 184 f.write(decomp_comparison) 185 f.write("\n") 186 if verification_result is not None: 187 f.write("\n## Verification results\n\n") 188 f.write(verification_result) 189 f.write("\n") 190 if profile_result is not None: 191 f.write("\n## Profiling result\n\n") 192 f.write("```\n") 193 f.write(profile_result) 194 f.write("```\n") 195