#!/usr/bin/env python3 """ Generates PyTorch ONNX Export Diagnostic rules for C++, Python and documentations. The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml. Usage: python -m tools.onnx.gen_diagnostics \ torch/onnx/_internal/diagnostics/rules.yaml \ torch/onnx/_internal/diagnostics \ torch/csrc/onnx/diagnostics/generated \ torch/docs/source """ import argparse import os import string import subprocess import textwrap from typing import Any, Mapping, Sequence import yaml from torchgen import utils as torchgen_utils from torchgen.yaml_utils import YamlLoader _RULES_GENERATED_COMMENT = """\ GENERATED CODE - DO NOT EDIT DIRECTLY This file is generated by gen_diagnostics.py. See tools/onnx/gen_diagnostics.py for more information. Diagnostic rules for PyTorch ONNX export. """ _PY_RULE_CLASS_COMMENT = """\ GENERATED CODE - DO NOT EDIT DIRECTLY The purpose of generating a class for each rule is to override the `format_message` method to provide more details in the signature about the format arguments. """ _PY_RULE_CLASS_TEMPLATE = """\ class _{pascal_case_name}(infra.Rule): \"\"\"{short_description}\"\"\" def format_message( # type: ignore[override] self, {message_arguments} ) -> str: \"\"\"Returns the formatted default message of this Rule. Message template: {message_template} \"\"\" return self.message_default_template.format({message_arguments_assigned}) def format( # type: ignore[override] self, level: infra.Level, {message_arguments} ) -> Tuple[infra.Rule, infra.Level, str]: \"\"\"Returns a tuple of (Rule, Level, message) for this Rule. Message template: {message_template} \"\"\" return self, level, self.format_message({message_arguments_assigned}) """ _PY_RULE_COLLECTION_FIELD_TEMPLATE = """\ {snake_case_name}: _{pascal_case_name} = dataclasses.field( default=_{pascal_case_name}.from_sarif(**{sarif_dict}), init=False, ) \"\"\"{short_description}\"\"\" """ _CPP_RULE_TEMPLATE = """\ /** * @brief {short_description} */ {name}, """ _RuleType = Mapping[str, Any] def _kebab_case_to_snake_case(name: str) -> str: return name.replace("-", "_") def _kebab_case_to_pascal_case(name: str) -> str: return "".join(word.capitalize() for word in name.split("-")) def _format_rule_for_python_class(rule: _RuleType) -> str: pascal_case_name = _kebab_case_to_pascal_case(rule["name"]) short_description = rule["short_description"]["text"] message_template = rule["message_strings"]["default"]["text"] field_names = [ field_name for _, field_name, _, _ in string.Formatter().parse(message_template) if field_name is not None ] for field_name in field_names: assert isinstance( field_name, str ), f"Unexpected field type {type(field_name)} from {field_name}. " "Field name must be string.\nFull message template: {message_template}" assert ( not field_name.isnumeric() ), f"Unexpected numeric field name {field_name}. " "Only keyword name formatting is supported.\nFull message template: {message_template}" message_arguments = ", ".join(field_names) message_arguments_assigned = ", ".join( [f"{field_name}={field_name}" for field_name in field_names] ) return _PY_RULE_CLASS_TEMPLATE.format( pascal_case_name=pascal_case_name, short_description=short_description, message_template=repr(message_template), message_arguments=message_arguments, message_arguments_assigned=message_arguments_assigned, ) def _format_rule_for_python_field(rule: _RuleType) -> str: snake_case_name = _kebab_case_to_snake_case(rule["name"]) pascal_case_name = _kebab_case_to_pascal_case(rule["name"]) short_description = rule["short_description"]["text"] return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format( snake_case_name=snake_case_name, pascal_case_name=pascal_case_name, sarif_dict=rule, short_description=short_description, ) def _format_rule_for_cpp(rule: _RuleType) -> str: name = f"k{_kebab_case_to_pascal_case(rule['name'])}" short_description = rule["short_description"]["text"] return _CPP_RULE_TEMPLATE.format(name=name, short_description=short_description) def gen_diagnostics_python( rules: Sequence[_RuleType], out_py_dir: str, template_dir: str ) -> None: rule_class_lines = [_format_rule_for_python_class(rule) for rule in rules] rule_field_lines = [_format_rule_for_python_field(rule) for rule in rules] fm = torchgen_utils.FileManager( install_dir=out_py_dir, template_dir=template_dir, dry_run=False ) fm.write_with_template( "_rules.py", "rules.py.in", lambda: { "generated_comment": _RULES_GENERATED_COMMENT, "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT, "rule_classes": "\n".join(rule_class_lines), "rules": textwrap.indent("\n".join(rule_field_lines), " " * 4), }, ) _lint_file(os.path.join(out_py_dir, "_rules.py")) def gen_diagnostics_cpp( rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str ) -> None: rule_lines = [_format_rule_for_cpp(rule) for rule in rules] rule_names = [f'"{_kebab_case_to_snake_case(rule["name"])}",' for rule in rules] fm = torchgen_utils.FileManager( install_dir=out_cpp_dir, template_dir=template_dir, dry_run=False ) fm.write_with_template( "rules.h", "rules.h.in", lambda: { "generated_comment": textwrap.indent( _RULES_GENERATED_COMMENT, " * ", predicate=lambda x: True, # Don't ignore empty line ), "rules": textwrap.indent("\n".join(rule_lines), " " * 2), "py_rule_names": textwrap.indent("\n".join(rule_names), " " * 4), }, ) _lint_file(os.path.join(out_cpp_dir, "rules.h")) def gen_diagnostics_docs( rules: Sequence[_RuleType], out_docs_dir: str, template_dir: str ) -> None: # TODO: Add doc generation in a follow-up PR. pass def _lint_file(file_path: str) -> None: p = subprocess.Popen(["lintrunner", "-a", file_path]) p.wait() def gen_diagnostics( rules_path: str, out_py_dir: str, out_cpp_dir: str, out_docs_dir: str, ) -> None: with open(rules_path) as f: rules = yaml.load(f, Loader=YamlLoader) template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates") gen_diagnostics_python( rules, out_py_dir, template_dir, ) gen_diagnostics_cpp( rules, out_cpp_dir, template_dir, ) gen_diagnostics_docs(rules, out_docs_dir, template_dir) def main() -> None: parser = argparse.ArgumentParser(description="Generate ONNX diagnostics files") parser.add_argument("rules_path", metavar="RULES", help="path to rules.yaml") parser.add_argument( "out_py_dir", metavar="OUT_PY", help="path to output directory for Python", ) parser.add_argument( "out_cpp_dir", metavar="OUT_CPP", help="path to output directory for C++", ) parser.add_argument( "out_docs_dir", metavar="OUT_DOCS", help="path to output directory for docs", ) args = parser.parse_args() gen_diagnostics( args.rules_path, args.out_py_dir, args.out_cpp_dir, args.out_docs_dir, ) if __name__ == "__main__": main()