xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/check_kernel_launches.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import os
4import re
5import sys
6from typing import List
7
# Public entry points of this module.
__all__ = ["check_code_for_cuda_kernel_launches", "check_cuda_kernel_launches"]
12
# Files to skip entirely, matched as path suffixes via `str.endswith`.
# Launching a kernel without a safety check is like driving without a
# seatbelt, so treat this list only as a temporary escape hatch: if the
# checker misfires on a file, add the file here, fix the checker, and then
# remove the file from this list again.
exclude_files: List[str] = list()
19
# Without a C++ AST we can't detect kernel launches with 100% accuracy, so
# we model them as having the pattern "<<<parameters>>>(arguments);".
# We then require that `C10_CUDA_KERNEL_LAUNCH_CHECK` be the next
# statement.
#
# We model the next statement as ending at the next `}` or `;`.
# If we see `}` first, a clause has ended (bad); if we see a semicolon,
# we expect the launch check to appear just before it.
#
# Since the kernel launch can include lambda statements, it's important
# to find the correct end-paren of the kernel launch. Doing this with
# pure regex requires recursive regexes, which Python's standard `re`
# module does not support. To avoid an additional dependency, we use a
# prefix regex to find the start of a kernel launch, a paren-matching
# algorithm to find the end of the launch, and then another regex to
# determine whether a launch check is present.
36
# Locates candidate kernel launches. A match begins at the start of the line
# containing `<<<` and ends just after the `(` that opens the launch's
# argument list (whitespace between `>>>` and `(` is allowed).
kernel_launch_start = re.compile(r"^.*<<<[^>]+>>>\s*\(", re.MULTILINE)
41
# Meant to be applied at the index just past a launch's closing paren.
# Despite its name, a MATCH signals a problem: it requires optional whitespace
# and a `;`, and succeeds only when `C10_CUDA_KERNEL_LAUNCH_CHECK();` does not
# appear before the next `;` or `}` -- i.e. the launch check is not the next
# statement. No match means the check is present (or no `;` follows).
has_check = re.compile(r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", re.MULTILINE)
47
def find_matching_paren(s: str, startpos: int) -> int:
    """Given a string "prefix (unknown number of characters) suffix"
    and the position of the first `(` returns the index of the character
    1 past the `)`, accounting for paren nesting.

    Every `(` and `)` character is counted, including any that appear inside
    string literals or comments.

    Args:
        s        - The text to scan.
        startpos - Index of the opening `(` in `s`.

    Returns:
        The index one past the `)` that balances the `(` at `startpos`.

    Raises:
        IndexError - If the parentheses are never balanced before the end of `s`.
    """
    depth = 0
    # Index into `s` directly rather than iterating over `s[startpos:]`: the
    # slice would copy the whole remainder of the file on every call, and
    # this is called once per kernel launch.
    for i in range(startpos, len(s)):
        c = s[i]
        if c == '(':
            depth += 1
        elif c == ')':
            depth -= 1
            if depth == 0:
                return i + 1

    raise IndexError("Closing parens not found!")
63
64
def should_exclude_file(filename) -> bool:
    """Return True when `filename` ends with one of the suffixes in `exclude_files`."""
    return any(filename.endswith(suffix) for suffix in exclude_files)
70
71
def check_code_for_cuda_kernel_launches(code, filename=None):
    """Checks code for CUDA kernel launches without cuda error checks.

    For every unsafe launch found, a message with the numbered source context
    is printed to stderr.

    Args:
        code     - The code to check
        filename - Filename of file containing the code. Used only for display
                   purposes, so you can put anything here.

    Returns:
        The number of unsafe kernel launches in the code
    """
    if filename is None:
        filename = "##Python Function Call##"

    # Prefix each line with its 1-based line number so the printed context
    # points at the same line an editor or compiler would report. The prefix
    # contains no `(`, `)`, `;`, `}` or `>`, so it does not change which
    # launches the regexes below find.
    numbered_code = "\n".join(
        f"{lineno}: {line}" for lineno, line in enumerate(code.split("\n"), start=1)
    )

    num_launches_without_checks = 0
    for m in kernel_launch_start.finditer(numbered_code):
        # The match ends just after the `(` opening the launch arguments.
        end_paren = find_matching_paren(numbered_code, m.end() - 1)
        # `has_check` matches when the launch check is NOT the next statement.
        if has_check.match(numbered_code, end_paren):
            num_launches_without_checks += 1
            context = numbered_code[m.start():end_paren + 1]
            print(f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. Context:\n{context}", file=sys.stderr)

    return num_launches_without_checks
101
102
def check_file(filename):
    """Checks a file for CUDA kernel launches without cuda error checks

    Only files ending in `.cu` or `.cuh` whose path does not end with a suffix
    listed in `exclude_files` are read; every other file counts as having
    0 unsafe launches.

    Args:
        filename - File to check

    Returns:
        The number of unsafe kernel launches in the file
    """
    if not filename.endswith((".cu", ".cuh")):
        return 0
    if should_exclude_file(filename):
        return 0
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(filename, encoding="utf-8") as fo:
        contents = fo.read()
    return check_code_for_cuda_kernel_launches(contents, filename)
120
121
def check_cuda_kernel_launches():
    """Checks all pytorch code for CUDA kernel launches without cuda error checks

    Walks the directory tree computed below, runs `check_file` on every file,
    and prints a summary of offending files to stderr.

    Returns:
        The number of unsafe kernel launches in the codebase
    """
    # This file lives in torch/testing/_internal/, so the three dirname()
    # steps from its real path land on the `torch` package directory.
    # NOTE(review): the original comments ("parent torch", "parent caffe2")
    # and the `$BASE/...` exclusions below read as if the repository root was
    # intended; confirm which directory should be walked.
    torch_dir = os.path.dirname(os.path.realpath(__file__))  # .../torch/testing/_internal
    torch_dir = os.path.dirname(torch_dir)  # .../torch/testing
    torch_dir = os.path.dirname(torch_dir)  # .../torch

    # `$BASE/build` and `$BASE/torch/include` are generated, so we don't want
    # to flag their contents. Each path is joined from separate components so
    # it compares equal to os.walk's `root` strings on Windows as well.
    generated_dirs = {
        os.path.join(torch_dir, "build"),
        os.path.join(torch_dir, "torch", "include"),
    }

    kernels_without_checks = 0
    files_without_checks = []
    for root, dirnames, filenames in os.walk(torch_dir):
        if root in generated_dirs:
            # Curtail the search by clearing dirnames in place; os.walk
            # honours in-place modification (see `help(os.walk)`).
            dirnames[:] = []
            continue

        for name in filenames:
            filename = os.path.join(root, name)
            file_result = check_file(filename)
            if file_result > 0:
                kernels_without_checks += file_result
                files_without_checks.append(filename)

    if kernels_without_checks > 0:
        count_str = (f"Found {kernels_without_checks} instances in "
                     f"{len(files_without_checks)} files where kernel "
                     "launches didn't have checks.")
        print(count_str, file=sys.stderr)
        print("Files without checks:", file=sys.stderr)
        for filename in files_without_checks:
            print(f"\t{filename}", file=sys.stderr)
        # Summary repeated after the file list (presumably so it is also
        # visible at the end of long output).
        print(count_str, file=sys.stderr)

    return kernels_without_checks
161
162
if __name__ == "__main__":
    # Exit non-zero whenever any unchecked kernel launch was reported.
    num_unsafe = check_cuda_kernel_launches()
    exit_code = 1 if num_unsafe != 0 else 0
    sys.exit(exit_code)
166