xref: /aosp_15_r20/external/pytorch/tools/linter/adapters/constexpr_linter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
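CONSTEXPR: Flags vanilla ``constexpr char`` declarations and suggests replacing
them with the ``CONSTEXPR_EXCEPT_WIN_CUDA char`` macro, because vanilla
constexpr causes build issues (per the macro name, on Windows CUDA builds).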
3"""
4
5from __future__ import annotations
6
7import argparse
8import json
9import logging
10import sys
11from enum import Enum
12from typing import NamedTuple
13
14
# Identifier emitted in the ``code`` field of every LintMessage this linter reports.
LINTER_CODE = "CONSTEXPR"

# Substring that marks a vanilla constexpr declaration to be flagged ...
CONSTEXPR = "constexpr char"
# ... and the macro-based spelling substituted for it in the suggested fix.
CONSTEXPR_MACRO = "CONSTEXPR_EXCEPT_WIN_CUDA char"
19
20
class LintSeverity(str, Enum):
    """Severity levels this linter can attach to a LintMessage.

    Mixing in ``str`` makes each member compare equal to its plain string
    value (e.g. ``LintSeverity.ERROR == "error"``), so it serializes to JSON
    as that string without a custom encoder.
    """

    ERROR = "error"
23
24
class LintMessage(NamedTuple):
    """One lint finding, emitted as a JSON object via ``_asdict()``.

    Field order is part of the interface: it fixes the positional constructor
    signature and the key order of ``_asdict()``.
    """

    path: str | None  # file the finding refers to
    line: int | None  # line number of the finding, if known
    char: int | None  # column of the finding, if known
    code: str  # identifier of the linter that produced the finding
    severity: LintSeverity  # how serious the finding is
    name: str  # short summary of the problem
    original: str | None  # file text before the suggested fix, if any
    replacement: str | None  # file text after the suggested fix, if any
    description: str | None  # longer human-readable explanation
35
36
def check_file(filename: str) -> LintMessage | None:
    """Check one file for vanilla ``constexpr char`` declarations.

    The whole file is read. If any line contains ``CONSTEXPR``, a single
    LintMessage is returned for the first such line. Its ``replacement`` is
    the full file text with *every* occurrence of ``CONSTEXPR`` swapped for
    ``CONSTEXPR_MACRO``, so applying it fixes all of them at once.

    Args:
        filename: path of the file to check.

    Returns:
        A LintMessage pointing at the first offending line (1-based), or
        ``None`` if the file contains no vanilla ``constexpr char``.
    """
    logging.debug("Checking file %s", filename)

    # newline="" disables universal-newline translation, so line endings
    # (e.g. CRLF) are kept verbatim. ``original``/``replacement`` then match
    # the on-disk text, and applying the patch does not silently rewrite
    # every line ending. An explicit encoding avoids depending on the locale.
    with open(filename, encoding="utf-8", newline="") as f:
        lines = f.readlines()

    # Line numbers reported to lintrunner are 1-based, so count from 1.
    for lineno, line in enumerate(lines, start=1):
        if CONSTEXPR in line:
            original = "".join(lines)
            replacement = original.replace(CONSTEXPR, CONSTEXPR_MACRO)
            logging.debug("replacement: %s", replacement)
            return LintMessage(
                path=filename,
                line=lineno,
                char=None,
                code=LINTER_CODE,
                severity=LintSeverity.ERROR,
                name="Vanilla constexpr used, prefer macros",
                original=original,
                replacement=replacement,
                description="Vanilla constexpr used, prefer macros run `lintrunner --take CONSTEXPR -a` to apply changes.",
            )
    return None
60
61
def main() -> None:
    """Parse the command line, lint every given file, and print the results.

    All files are checked first. Each resulting LintMessage is then written
    to stdout as one JSON object per line, and log output goes to stderr.
    """
    parser = argparse.ArgumentParser(
        description="CONSTEXPR linter",
        fromfile_prefix_chars="@",
    )
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("filenames", nargs="+", help="paths to lint")
    args = parser.parse_args()

    # --verbose: NOTSET; otherwise DEBUG for fewer than 1000 files, else INFO.
    if args.verbose:
        log_level = logging.NOTSET
    elif len(args.filenames) < 1000:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=log_level,
        stream=sys.stderr,
    )

    findings = [
        finding
        for finding in map(check_file, args.filenames)
        if finding is not None
    ]
    for finding in findings:
        print(json.dumps(finding._asdict()), flush=True)


if __name__ == "__main__":
    main()
97