1from __future__ import annotations 2 3import argparse 4import concurrent.futures 5import json 6import logging 7import os 8import subprocess 9import sys 10import time 11from enum import Enum 12from typing import Any, BinaryIO, NamedTuple 13 14 15IS_WINDOWS: bool = os.name == "nt" 16 17 18def eprint(*args: Any, **kwargs: Any) -> None: 19 print(*args, file=sys.stderr, flush=True, **kwargs) 20 21 22class LintSeverity(str, Enum): 23 ERROR = "error" 24 WARNING = "warning" 25 ADVICE = "advice" 26 DISABLED = "disabled" 27 28 29class LintMessage(NamedTuple): 30 path: str | None 31 line: int | None 32 char: int | None 33 code: str 34 severity: LintSeverity 35 name: str 36 original: str | None 37 replacement: str | None 38 description: str | None 39 40 41def as_posix(name: str) -> str: 42 return name.replace("\\", "/") if IS_WINDOWS else name 43 44 45def _run_command( 46 args: list[str], 47 *, 48 stdin: BinaryIO, 49 timeout: int, 50) -> subprocess.CompletedProcess[bytes]: 51 logging.debug("$ %s", " ".join(args)) 52 start_time = time.monotonic() 53 try: 54 return subprocess.run( 55 args, 56 stdin=stdin, 57 capture_output=True, 58 shell=IS_WINDOWS, # So batch scripts are found. 59 timeout=timeout, 60 check=True, 61 ) 62 finally: 63 end_time = time.monotonic() 64 logging.debug("took %dms", (end_time - start_time) * 1000) 65 66 67def run_command( 68 args: list[str], 69 *, 70 stdin: BinaryIO, 71 retries: int, 72 timeout: int, 73) -> subprocess.CompletedProcess[bytes]: 74 remaining_retries = retries 75 while True: 76 try: 77 return _run_command(args, stdin=stdin, timeout=timeout) 78 except subprocess.TimeoutExpired as err: 79 if remaining_retries == 0: 80 raise err 81 remaining_retries -= 1 82 logging.warning( 83 "(%s/%s) Retrying because command failed with: %r", 84 retries - remaining_retries, 85 retries, 86 err, 87 ) 88 time.sleep(1) 89 90 91def check_file( 92 filename: str, 93 retries: int, 94 timeout: int, 95) -> list[LintMessage]: 96 try: 97 with open(filename, "rb") as f: 98 original = f.read() 99 with open(filename, "rb") as f: 100 proc = run_command( 101 [sys.executable, "-mblack", "--stdin-filename", filename, "-"], 102 stdin=f, 103 retries=retries, 104 timeout=timeout, 105 ) 106 except subprocess.TimeoutExpired: 107 return [ 108 LintMessage( 109 path=filename, 110 line=None, 111 char=None, 112 code="BLACK", 113 severity=LintSeverity.ERROR, 114 name="timeout", 115 original=None, 116 replacement=None, 117 description=( 118 "black timed out while trying to process a file. " 119 "Please report an issue in pytorch/pytorch with the " 120 "label 'module: lint'" 121 ), 122 ) 123 ] 124 except (OSError, subprocess.CalledProcessError) as err: 125 return [ 126 LintMessage( 127 path=filename, 128 line=None, 129 char=None, 130 code="BLACK", 131 severity=LintSeverity.ADVICE, 132 name="command-failed", 133 original=None, 134 replacement=None, 135 description=( 136 f"Failed due to {err.__class__.__name__}:\n{err}" 137 if not isinstance(err, subprocess.CalledProcessError) 138 else ( 139 "COMMAND (exit code {returncode})\n" 140 "{command}\n\n" 141 "STDERR\n{stderr}\n\n" 142 "STDOUT\n{stdout}" 143 ).format( 144 returncode=err.returncode, 145 command=" ".join(as_posix(x) for x in err.cmd), 146 stderr=err.stderr.decode("utf-8").strip() or "(empty)", 147 stdout=err.stdout.decode("utf-8").strip() or "(empty)", 148 ) 149 ), 150 ) 151 ] 152 153 replacement = proc.stdout 154 if original == replacement: 155 return [] 156 157 return [ 158 LintMessage( 159 path=filename, 160 line=None, 161 char=None, 162 code="BLACK", 163 severity=LintSeverity.WARNING, 164 name="format", 165 original=original.decode("utf-8"), 166 replacement=replacement.decode("utf-8"), 167 description="Run `lintrunner -a` to apply this patch.", 168 ) 169 ] 170 171 172def main() -> None: 173 parser = argparse.ArgumentParser( 174 description="Format files with black.", 175 fromfile_prefix_chars="@", 176 ) 177 parser.add_argument( 178 "--retries", 179 default=3, 180 type=int, 181 help="times to retry timed out black", 182 ) 183 parser.add_argument( 184 "--timeout", 185 default=90, 186 type=int, 187 help="seconds to wait for black", 188 ) 189 parser.add_argument( 190 "--verbose", 191 action="store_true", 192 help="verbose logging", 193 ) 194 parser.add_argument( 195 "filenames", 196 nargs="+", 197 help="paths to lint", 198 ) 199 args = parser.parse_args() 200 201 logging.basicConfig( 202 format="<%(threadName)s:%(levelname)s> %(message)s", 203 level=logging.NOTSET 204 if args.verbose 205 else logging.DEBUG 206 if len(args.filenames) < 1000 207 else logging.INFO, 208 stream=sys.stderr, 209 ) 210 211 with concurrent.futures.ThreadPoolExecutor( 212 max_workers=os.cpu_count(), 213 thread_name_prefix="Thread", 214 ) as executor: 215 futures = { 216 executor.submit(check_file, x, args.retries, args.timeout): x 217 for x in args.filenames 218 } 219 for future in concurrent.futures.as_completed(futures): 220 try: 221 for lint_message in future.result(): 222 print(json.dumps(lint_message._asdict()), flush=True) 223 except Exception: 224 logging.critical('Failed at "%s".', futures[future]) 225 raise 226 227 228if __name__ == "__main__": 229 main() 230