xref: /aosp_15_r20/external/pytorch/tools/linter/adapters/clangformat_linter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import argparse
4import concurrent.futures
5import json
6import logging
7import os
8import subprocess
9import sys
10import time
11from enum import Enum
12from pathlib import Path
13from typing import Any, NamedTuple
14
15
# True when running on Windows. Used to run subprocesses through the shell,
# to normalize the --binary path, and to convert backslashes when displaying
# commands.
IS_WINDOWS: bool = os.name == "nt"
17
18
def eprint(*args: Any, **kwargs: Any) -> None:
    """Print ``args`` to standard error, flushing immediately.

    Remaining keyword arguments are forwarded to :func:`print`. ``file`` and
    ``flush`` are always supplied here, so passing either raises ``TypeError``.
    """
    stream = sys.stderr
    print(*args, file=stream, flush=True, **kwargs)
21
22
class LintSeverity(str, Enum):
    """Severity attached to each emitted ``LintMessage``.

    Mixing in ``str`` makes ``json.dumps`` serialize a member as its plain
    string value (e.g. ``"error"``).
    """

    # Used in this file for clang-format timeouts and a missing binary.
    ERROR = "error"
    # Used in this file for formatting differences.
    WARNING = "warning"
    # Used in this file when reading the file or running clang-format fails.
    ADVICE = "advice"
    # Not emitted by this file.
    DISABLED = "disabled"
28
29
class LintMessage(NamedTuple):
    """One lint result, printed to stdout as a JSON object via ``_asdict()``.

    Field order determines both the positional constructor and the key order
    of the emitted JSON. Presumably the consumer is lintrunner (the format
    message suggests ``lintrunner -a``); confirm the exact schema there.
    """

    # File the message applies to; None for run-wide messages (init-error).
    path: str | None
    # Always None in this linter: messages cover the whole file.
    line: int | None
    # Always None in this linter: messages cover the whole file.
    char: int | None
    # Linter identifier; this file always uses "CLANGFORMAT".
    code: str
    severity: LintSeverity
    # Short message kind, e.g. "format", "timeout", "command-failed".
    name: str
    # Full original file text for "format" messages, otherwise None.
    original: str | None
    # Full clang-formatted file text for "format" messages, otherwise None.
    replacement: str | None
    # Human-readable explanation.
    description: str | None
39    description: str | None
40
41
def as_posix(name: str) -> str:
    """Return ``name`` with backslashes turned into forward slashes on Windows.

    On other platforms ``name`` is returned unchanged.
    """
    if not IS_WINDOWS:
        return name
    return name.replace("\\", "/")
44
45
def _run_command(
    args: list[str],
    *,
    timeout: int,
) -> subprocess.CompletedProcess[bytes]:
    """Run ``args`` once, capturing stdout/stderr as bytes.

    The command line is logged at DEBUG level before running, and the elapsed
    wall time is logged afterwards, even when the call raises.

    Raises:
        subprocess.CalledProcessError: on a non-zero exit status (``check=True``).
        subprocess.TimeoutExpired: if the command runs longer than ``timeout`` seconds.
    """
    logging.debug("$ %s", " ".join(args))
    started = time.monotonic()
    try:
        completed = subprocess.run(
            args,
            capture_output=True,
            check=True,
            timeout=timeout,
            shell=IS_WINDOWS,  # So batch scripts are found.
        )
    finally:
        elapsed_ms = (time.monotonic() - started) * 1000
        logging.debug("took %dms", elapsed_ms)
    return completed
63        logging.debug("took %dms", (end_time - start_time) * 1000)
64
65
def run_command(
    args: list[str],
    *,
    retries: int,
    timeout: int,
) -> subprocess.CompletedProcess[bytes]:
    """Run ``args`` via ``_run_command``, retrying only when it times out.

    After each timeout (except the last) a warning is logged and the thread
    sleeps for one second before the next attempt. Any other exception from
    ``_run_command`` propagates immediately without retrying.

    Args:
        args: argv of the command to run.
        retries: number of extra attempts after a timeout. Values <= 0 mean a
            single attempt.
        timeout: per-attempt timeout in seconds.

    Raises:
        subprocess.TimeoutExpired: if the last permitted attempt also times out.
    """
    remaining_retries = retries
    while True:
        try:
            return _run_command(args, timeout=timeout)
        except subprocess.TimeoutExpired as err:
            # `<= 0` rather than `== 0`: with a negative --retries the counter
            # would never reach zero and every timeout would retry forever.
            if remaining_retries <= 0:
                raise
            remaining_retries -= 1
            logging.warning(
                "(%s/%s) Retrying because command failed with: %r",
                retries - remaining_retries,
                retries,
                err,
            )
            time.sleep(1)
86            time.sleep(1)
87
88
def _decode_output(data: bytes | None) -> str:
    """Decode captured process output for display, never raising on bad bytes."""
    if not data:
        return "(empty)"
    # errors="replace": this text only goes into a description, so a lossy
    # decode is preferable to crashing the whole lint run.
    return data.decode("utf-8", errors="replace").strip() or "(empty)"


def _failure_description(err: BaseException) -> str:
    """Describe why linting a file failed, including process output if any."""
    if not isinstance(err, subprocess.CalledProcessError):
        return f"Failed due to {err.__class__.__name__}:\n{err}"
    return (
        "COMMAND (exit code {returncode})\n"
        "{command}\n\n"
        "STDERR\n{stderr}\n\n"
        "STDOUT\n{stdout}"
    ).format(
        returncode=err.returncode,
        command=" ".join(as_posix(x) for x in err.cmd),
        stderr=_decode_output(err.stderr),
        stdout=_decode_output(err.stdout),
    )


def _clangformat_message(
    filename: str,
    *,
    severity: LintSeverity,
    name: str,
    description: str,
    original: str | None = None,
    replacement: str | None = None,
) -> LintMessage:
    """Build a whole-file (no line/char) CLANGFORMAT message for ``filename``."""
    return LintMessage(
        path=filename,
        line=None,
        char=None,
        code="CLANGFORMAT",
        severity=severity,
        name=name,
        original=original,
        replacement=replacement,
        description=description,
    )


def check_file(
    filename: str,
    binary: str,
    retries: int,
    timeout: int,
) -> list[LintMessage]:
    """Run clang-format on ``filename`` and report any formatting difference.

    ``binary`` is invoked as ``[binary, filename]``, and its stdout is compared
    byte-for-byte with the file's current contents.

    Args:
        filename: path of the file to check.
        binary: path to the clang-format executable.
        retries: extra attempts after a clang-format timeout.
        timeout: seconds allowed per clang-format attempt.

    Returns:
        ``[]`` if the formatted output equals the file. Otherwise a single
        message, one of:

        * WARNING ``format``, carrying the original and formatted text;
        * ERROR ``timeout``, if every attempt timed out;
        * ADVICE ``command-failed``, if reading the file or running
          clang-format failed, or if the file or clang-format's output is not
          valid UTF-8.
    """
    try:
        with open(filename, "rb") as f:
            original = f.read()
        proc = run_command(
            [binary, filename],
            retries=retries,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return [
            _clangformat_message(
                filename,
                severity=LintSeverity.ERROR,
                name="timeout",
                description=(
                    "clang-format timed out while trying to process a file. "
                    "Please report an issue in pytorch/pytorch with the "
                    "label 'module: lint'"
                ),
            )
        ]
    except (OSError, subprocess.CalledProcessError) as err:
        return [
            _clangformat_message(
                filename,
                severity=LintSeverity.ADVICE,
                name="command-failed",
                description=_failure_description(err),
            )
        ]

    replacement = proc.stdout
    if original == replacement:
        return []

    try:
        original_text = original.decode("utf-8")
        replacement_text = replacement.decode("utf-8")
    except UnicodeDecodeError as err:
        # Previously this exception escaped and aborted the whole run. A lossy
        # decode is not an option here: the replacement text is what
        # `lintrunner -a` writes back, so it would corrupt the file. Report it.
        return [
            _clangformat_message(
                filename,
                severity=LintSeverity.ADVICE,
                name="command-failed",
                description=_failure_description(err),
            )
        ]

    return [
        _clangformat_message(
            filename,
            severity=LintSeverity.WARNING,
            name="format",
            original=original_text,
            replacement=replacement_text,
            description="See https://clang.llvm.org/docs/ClangFormat.html.\nRun `lintrunner -a` to apply this patch.",
        )
    ]
166    ]
167
168
def main() -> None:
    """Command-line entry point: check each given file with clang-format.

    Each resulting ``LintMessage`` is printed to stdout as one JSON object per
    line. Files are checked concurrently on a thread pool.

    If the clang-format binary does not exist, a single ``init-error`` message
    is printed and the process exits with status 0. If checking a file raises,
    the file name is logged at CRITICAL level and the exception is re-raised.
    """
    parser = argparse.ArgumentParser(
        description="Format files with clang-format.",
        fromfile_prefix_chars="@",
    )
    parser.add_argument("--binary", required=True, help="clang-format binary path")
    parser.add_argument(
        "--retries", default=3, type=int, help="times to retry timed out clang-format"
    )
    parser.add_argument(
        "--timeout", default=90, type=int, help="seconds to wait for clang-format"
    )
    parser.add_argument("--verbose", action="store_true", help="verbose logging")
    parser.add_argument("filenames", nargs="+", help="paths to lint")
    args = parser.parse_args()

    if args.verbose:
        log_level = logging.NOTSET
    elif len(args.filenames) < 1000:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=log_level,
        stream=sys.stderr,
    )

    binary = args.binary
    if IS_WINDOWS:
        binary = os.path.normpath(binary)

    def emit(message: LintMessage) -> None:
        # One JSON object per line, flushed so results stream out promptly.
        print(json.dumps(message._asdict()), flush=True)

    if not Path(binary).exists():
        emit(
            LintMessage(
                path=None,
                line=None,
                char=None,
                code="CLANGFORMAT",
                severity=LintSeverity.ERROR,
                name="init-error",
                original=None,
                replacement=None,
                description=(
                    f"Could not find clang-format binary at {binary}, "
                    "did you forget to run `lintrunner init`?"
                ),
            )
        )
        sys.exit(0)

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count(),
        thread_name_prefix="Thread",
    ) as executor:
        pending = {}
        for filename in args.filenames:
            submitted = executor.submit(
                check_file, filename, binary, args.retries, args.timeout
            )
            pending[submitted] = filename

        for done in concurrent.futures.as_completed(pending):
            try:
                messages = done.result()
                for message in messages:
                    emit(message)
            except Exception:
                logging.critical('Failed at "%s".', pending[done])
                raise
246                raise
247
248
# Run the linter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
251