xref: /aosp_15_r20/external/pytorch/tools/linter/adapters/black_linter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import argparse
4import concurrent.futures
5import json
6import logging
7import os
8import subprocess
9import sys
10import time
11from enum import Enum
12from typing import Any, BinaryIO, NamedTuple
13
14
# True on Windows; there, commands are launched through the shell and
# backslash paths are converted to forward slashes for display.
IS_WINDOWS: bool = (os.name == "nt")
16
17
def eprint(*args: Any, **kwargs: Any) -> None:
    """Print ``args`` to standard error and flush immediately.

    Other keyword arguments (e.g. ``sep``, ``end``) are forwarded to
    :func:`print`; ``file`` and ``flush`` are fixed, so passing either one
    raises ``TypeError``.
    """
    destination = sys.stderr
    print(*args, **kwargs, file=destination, flush=True)
20
21
class LintSeverity(str, Enum):
    """Severity level attached to every emitted lint message.

    Because the enum mixes in ``str``, each member *is* a string equal to its
    value, so ``json.dumps`` writes it as e.g. ``"error"`` with no custom
    encoder.
    """

    ERROR = "error"  # this adapter uses it when black keeps timing out
    WARNING = "warning"  # this adapter uses it for "format" (needs reformatting)
    ADVICE = "advice"  # this adapter uses it for "command-failed"
    DISABLED = "disabled"  # not emitted anywhere in this adapter
27
28
class LintMessage(NamedTuple):
    """A single lint result; ``main`` prints ``_asdict()`` of each as JSON.

    Field order matters: it is the positional-constructor order and the key
    order of the serialized JSON object.
    """

    path: str | None  # file the message is about
    line: int | None  # line number; this adapter always passes None
    char: int | None  # column; this adapter always passes None
    code: str  # linter code; this adapter always uses "BLACK"
    severity: LintSeverity  # how serious the finding is
    name: str  # short identifier, e.g. "format", "timeout", "command-failed"
    original: str | None  # full original file text, for "format" messages
    replacement: str | None  # full black-formatted text, for "format" messages
    description: str | None  # human-readable explanation
38    description: str | None
39
40
def as_posix(name: str) -> str:
    """Return ``name`` with backslashes turned into forward slashes on Windows.

    On any other OS the string is returned unchanged.
    """
    if not IS_WINDOWS:
        return name
    return name.replace("\\", "/")
43
44
def _run_command(
    args: list[str],
    *,
    stdin: BinaryIO,
    timeout: int,
) -> subprocess.CompletedProcess[bytes]:
    """Run ``args`` once with ``stdin`` as input, capturing stdout and stderr.

    Because ``check=True`` is used, a non-zero exit raises
    ``subprocess.CalledProcessError``; exceeding ``timeout`` seconds raises
    ``subprocess.TimeoutExpired``. The elapsed wall-clock time is logged at
    DEBUG level whether the command succeeds or raises.
    """
    logging.debug("$ %s", " ".join(args))
    started = time.monotonic()
    try:
        completed = subprocess.run(
            args,
            stdin=stdin,
            capture_output=True,
            check=True,
            timeout=timeout,
            shell=IS_WINDOWS,  # So batch scripts are found.
        )
    finally:
        elapsed_ms = (time.monotonic() - started) * 1000
        logging.debug("took %dms", elapsed_ms)
    return completed
64        logging.debug("took %dms", (end_time - start_time) * 1000)
65
66
def run_command(
    args: list[str],
    *,
    stdin: BinaryIO,
    retries: int,
    timeout: int,
) -> subprocess.CompletedProcess[bytes]:
    """Run ``args``, retrying up to ``retries`` more times if it times out.

    Only ``subprocess.TimeoutExpired`` triggers a retry, after a one-second
    pause; any other exception propagates immediately. When no retries are
    left, the last ``TimeoutExpired`` is re-raised.

    Every attempt receives the same input. Before each retry, the descriptor
    behind ``stdin`` is rewound to the offset it had before the first attempt.
    Otherwise a child that had already consumed its input before timing out
    would leave the shared offset at EOF, and the retry would see empty input.
    """
    # The child reads the inherited file descriptor directly, so save and
    # restore the OS-level offset rather than the Python-level file position.
    start_offset: int | None
    try:
        start_offset = os.lseek(stdin.fileno(), 0, os.SEEK_CUR)
    except OSError:
        # No seekable descriptor (e.g. a pipe): there is nothing to rewind.
        start_offset = None

    remaining_retries = retries
    while True:
        try:
            return _run_command(args, stdin=stdin, timeout=timeout)
        except subprocess.TimeoutExpired as err:
            # `<= 0` rather than `== 0` so a negative retry count cannot
            # make this loop retry forever.
            if remaining_retries <= 0:
                raise
            remaining_retries -= 1
            logging.warning(
                "(%s/%s) Retrying because command failed with: %r",
                retries - remaining_retries,
                retries,
                err,
            )
            time.sleep(1)
            if start_offset is not None:
                os.lseek(stdin.fileno(), start_offset, os.SEEK_SET)
88            time.sleep(1)
89
90
def _describe_command_failure(
    err: OSError | subprocess.CalledProcessError,
) -> str:
    """Build the ``description`` of a ``command-failed`` lint message.

    For a non-zero exit, show the command, its exit code, and black's captured
    stderr/stdout. For any other error (e.g. the file could not be opened),
    show the exception type and text. Captured output is decoded with
    ``errors="replace"`` because nothing guarantees it is valid UTF-8. A strict
    decode would raise from inside ``check_file``'s ``except`` handler instead
    of producing this message.
    """
    if not isinstance(err, subprocess.CalledProcessError):
        return f"Failed due to {err.__class__.__name__}:\n{err}"

    def _show(output: bytes | None) -> str:
        # Tolerate missing output and undecodable bytes; "(empty)" marks no text.
        text = (output or b"").decode("utf-8", errors="replace").strip()
        return text or "(empty)"

    return (
        "COMMAND (exit code {returncode})\n"
        "{command}\n\n"
        "STDERR\n{stderr}\n\n"
        "STDOUT\n{stdout}"
    ).format(
        returncode=err.returncode,
        command=" ".join(as_posix(x) for x in err.cmd),
        stderr=_show(err.stderr),
        stdout=_show(err.stdout),
    )


def check_file(
    filename: str,
    retries: int,
    timeout: int,
) -> list[LintMessage]:
    """Run black on ``filename`` and report any change it would make.

    Args:
        filename: path of the file to check; also passed to black as
            ``--stdin-filename``.
        retries: extra attempts allowed when black times out.
        timeout: seconds to wait for each black invocation.

    Returns:
        ``[]`` if black's output is byte-identical to the file. Otherwise a
        single message:
        - ``format`` (WARNING), carrying the original and formatted text;
        - ``timeout`` (ERROR), if black still timed out after all retries;
        - ``command-failed`` (ADVICE), on an ``OSError`` (e.g. the file could
          not be opened) or when black exited with a non-zero status.
    """
    try:
        with open(filename, "rb") as f:
            original = f.read()
        # A fresh handle starts at offset 0. The "-" argument makes black read
        # the source from stdin; its stdout is used as the formatted result.
        with open(filename, "rb") as f:
            proc = run_command(
                [sys.executable, "-mblack", "--stdin-filename", filename, "-"],
                stdin=f,
                retries=retries,
                timeout=timeout,
            )
    except subprocess.TimeoutExpired:
        return [
            LintMessage(
                path=filename,
                line=None,
                char=None,
                code="BLACK",
                severity=LintSeverity.ERROR,
                name="timeout",
                original=None,
                replacement=None,
                description=(
                    "black timed out while trying to process a file. "
                    "Please report an issue in pytorch/pytorch with the "
                    "label 'module: lint'"
                ),
            )
        ]
    except (OSError, subprocess.CalledProcessError) as err:
        return [
            LintMessage(
                path=filename,
                line=None,
                char=None,
                code="BLACK",
                severity=LintSeverity.ADVICE,
                name="command-failed",
                original=None,
                replacement=None,
                description=_describe_command_failure(err),
            )
        ]

    replacement = proc.stdout
    if original == replacement:
        return []

    # NOTE(review): decoding below is strict. A non-UTF-8 file that black
    # would reformat raises UnicodeDecodeError here, which main() logs and
    # re-raises, aborting the run. Consider reporting it as a lint message.
    return [
        LintMessage(
            path=filename,
            line=None,
            char=None,
            code="BLACK",
            severity=LintSeverity.WARNING,
            name="format",
            original=original.decode("utf-8"),
            replacement=replacement.decode("utf-8"),
            description="Run `lintrunner -a` to apply this patch.",
        )
    ]
169    ]
170
171
def main() -> None:
    """Command-line entry point: check every given file with black in parallel.

    Each resulting ``LintMessage`` is written to stdout as one JSON object per
    line. If checking a file raises, the offending path is logged at CRITICAL
    level and the exception is re-raised.
    """
    parser = argparse.ArgumentParser(
        description="Format files with black.",
        fromfile_prefix_chars="@",
    )
    # (flag, add_argument options), registered in this order.
    argument_specs = (
        (
            "--retries",
            {"default": 3, "type": int, "help": "times to retry timed out black"},
        ),
        (
            "--timeout",
            {"default": 90, "type": int, "help": "seconds to wait for black"},
        ),
        ("--verbose", {"action": "store_true", "help": "verbose logging"}),
        ("filenames", {"nargs": "+", "help": "paths to lint"}),
    )
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # --verbose logs everything (NOTSET on the root logger). Otherwise DEBUG
    # detail is kept only for runs over fewer than 1000 files.
    if args.verbose:
        level = logging.NOTSET
    elif len(args.filenames) < 1000:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(
        format="<%(threadName)s:%(levelname)s> %(message)s",
        level=level,
        stream=sys.stderr,
    )

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=os.cpu_count(),
        thread_name_prefix="Thread",
    ) as executor:
        path_for_future = {}
        for path in args.filenames:
            pending = executor.submit(check_file, path, args.retries, args.timeout)
            path_for_future[pending] = path
        for done in concurrent.futures.as_completed(path_for_future):
            try:
                for message in done.result():
                    print(json.dumps(message._asdict()), flush=True)
            except Exception:
                logging.critical('Failed at "%s".', path_for_future[done])
                raise
225                raise
226
227
# Run the linter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
230