xref: /aosp_15_r20/external/pytorch/tools/iwyu/fixup.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import re
2import sys
3
4
# Match an #include directive at the start of a line and capture the header
# path in group 1. The capture uses a negated character class rather than a
# greedy ".*" so that it stops at the FIRST closing delimiter; otherwise a
# trailing comment that itself contains '"' or '>' (e.g. "// for operator>>")
# would be swallowed into the captured path.
QUOTE_INCLUDE_RE = re.compile(r'^#include "([^"]*)"')
ANGLE_INCLUDE_RE = re.compile(r"^#include <([^>]*)>")
7
# iwyu suggests the C spelling of standard headers by default; we prefer the
# C++ spelling. Each key is the C header "<name.h>" and each value is its C++
# counterpart "<cname>".
STD_C_HEADER_MAP = {
    f"<{stem}.h>": f"<c{stem}>"
    for stem in (
        "assert",
        "complex",
        "ctype",
        "errno",
        "fenv",
        "float",
        "inttypes",
        "iso646",
        "limits",
        "locale",
        "math",
        "setjmp",
        "signal",
        "stdalign",
        "stdarg",
        "stdbool",
        "stddef",
        "stdint",
        "stdio",
        "stdlib",
        "string",
        "tgmath",
        "time",
        "uchar",
        "wchar",
        "wctype",
    )
}
37
38
def _fixup_line(line: str) -> str:
    """Return *line* with a leading ``#include`` directive rewritten.

    Quoted includes are converted to angle-bracket form, and angle-bracket
    includes found in ``STD_C_HEADER_MAP`` are replaced by the mapped header.
    Lines that do not start with ``#include`` are returned unchanged.
    """
    match = QUOTE_INCLUDE_RE.match(line)
    if match is not None:
        # "path" and <path> have the same length, so whatever follows the
        # directive is copied through untouched.
        return f"#include <{match.group(1)}>{line[match.end(0):]}"

    match = ANGLE_INCLUDE_RE.match(line)
    if match is None:
        return line

    path = f"<{match.group(1)}>"
    new_path = STD_C_HEADER_MAP.get(path, path)
    tail = line[match.end(0):]
    if len(tail) > 1:
        # The tail holds more than a bare newline. Pad it by exactly the
        # number of characters the header shrank, so the text after the
        # directive stays in its original column (presumably the point of
        # the padding -- NOTE(review): confirm). A header that was not
        # replaced gets no padding, so re-running the script does not keep
        # pushing the tail to the right. A non-positive count yields "".
        tail = " " * (len(path) - len(new_path)) + tail
    return f"#include {new_path}{tail}"


def main() -> None:
    """Filter stdin to stdout, rewriting ``#include`` lines one at a time."""
    for line in sys.stdin:
        # Each line keeps its own line terminator, so print must not add one.
        print(_fixup_line(line), end="")
58
59
# Run the stdin-to-stdout filter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
62