# mypy: ignore-errors

import os
import re
import sys
from typing import List

__all__ = [
    "check_code_for_cuda_kernel_launches",
    "check_cuda_kernel_launches",
]

# FILES TO EXCLUDE (match is done with suffix using `endswith`)
# You wouldn't drive without a seatbelt, though, so why would you
# launch a kernel without some safety? Use this as a quick workaround
# for a problem with the checker, fix the checker, then de-exclude
# the files in question.
exclude_files: List[str] = []

# Without using a C++ AST we can't 100% detect kernel launches, so we
# model them as having the pattern "<<<parameters>>>(arguments);"
# We then require that `C10_CUDA_KERNEL_LAUNCH_CHECK` be
# the next statement.
#
# We model the next statement as ending at the next `}` or `;`.
# If we see `}` then a clause ended (bad); if we see a semicolon then
# we expect the launch check just before it.
#
# Since the kernel launch can include lambda statements, it's important
# to find the correct end-paren of the kernel launch. Doing this with
# pure regex requires recursive regex, which isn't part of the Python
# standard library. To avoid an additional dependency, we build a prefix
# regex that finds the start of a kernel launch, use a paren-matching
# algorithm to find the end of the launch, and then another regex to
# determine if a launch check is present.

# Finds potential starts of kernel launches. A match begins at the start of
# the line and ends just past the `(` that opens the launch arguments.
kernel_launch_start = re.compile(
    r"^.*<<<[^>]+>>>\s*\(", flags=re.MULTILINE
)

# This pattern should start at the character after the final paren of the
# kernel launch. It returns a match if the launch check is not the next
# statement: it matches the `;` that terminates the launch unless
# `C10_CUDA_KERNEL_LAUNCH_CHECK();` follows it with no `;` or `}` in between.
# Note: if anything other than whitespace precedes that `;`, there is no match
# and the launch is treated as checked.
has_check = re.compile(
    r"\s*;(?![^;}]*C10_CUDA_KERNEL_LAUNCH_CHECK\(\);)", flags=re.MULTILINE
)


def find_matching_paren(s: str, startpos: int) -> int:
    """Find the end of a parenthesized group, accounting for nesting.

    Given a string "prefix (unknown number of characters) suffix" and the
    position of the first `(`, returns the index of the character 1 past the
    matching `)`.

    No C++ lexing is done, so parens inside string/char literals or comments
    are counted like any other paren.

    Args:
        s: The text to scan.
        startpos: Index of the opening `(`.

    Returns:
        The index one past the matching `)`.

    Raises:
        IndexError: If the group opened at ``startpos`` is never closed.
    """
    opening = 0
    # Index into ``s`` directly instead of iterating over ``s[startpos:]`` so
    # that each call does not copy the remainder of the (possibly large) file.
    for i in range(startpos, len(s)):
        c = s[i]
        if c == "(":
            opening += 1
        elif c == ")":
            opening -= 1
        if opening == 0:
            return i + 1

    raise IndexError("Closing parens not found!")


def should_exclude_file(filename) -> bool:
    """Return True if ``filename`` ends with any suffix listed in ``exclude_files``."""
    return any(filename.endswith(suffix) for suffix in exclude_files)


def check_code_for_cuda_kernel_launches(code, filename=None):
    """Checks code for CUDA kernel launches without cuda error checks.

    Each offending launch is reported on stderr together with its source
    context, prefixed by 1-based line numbers.

    Args:
        code - The code to check
        filename - Filename of file containing the code. Used only for display
            purposes, so you can put anything here.

    Returns:
        The number of unsafe kernel launches in the code

    Raises:
        IndexError: If a detected launch's argument list has unbalanced parens.
    """
    if filename is None:
        filename = "##Python Function Call##"

    # We break the code apart and put it back together to add helpful line
    # numberings for identifying problem areas. Numbering starts at 1 so it
    # agrees with editors and compiler diagnostics. The "N: " prefixes contain
    # no `;` or `}`, so they do not affect `has_check`.
    code = "\n".join(
        f"{lineno}: {linecode}"
        for lineno, linecode in enumerate(code.split("\n"), start=1)
    )

    num_launches_without_checks = 0
    for m in kernel_launch_start.finditer(code):
        # m.end() - 1 is the `(` that opens the launch arguments.
        end_paren = find_matching_paren(code, m.end() - 1)
        if has_check.match(code, end_paren):
            num_launches_without_checks += 1
            context = code[m.start():end_paren + 1]
            print(
                f"Missing C10_CUDA_KERNEL_LAUNCH_CHECK in '{filename}'. "
                f"Context:\n{context}",
                file=sys.stderr,
            )

    return num_launches_without_checks


def check_file(filename):
    """Checks a file for CUDA kernel launches without cuda error checks

    Only `.cu` and `.cuh` files that are not listed in ``exclude_files`` are
    inspected; anything else counts as 0.

    Args:
        filename - File to check

    Returns:
        The number of unsafe kernel launches in the file
    """
    if not filename.endswith((".cu", ".cuh")):
        return 0
    if should_exclude_file(filename):
        return 0
    # Decode explicitly as UTF-8: the default encoding is locale-dependent
    # (e.g. cp1252 on Windows) and can fail on UTF-8 source files.
    with open(filename, encoding="utf-8") as fo:
        contents = fo.read()
    return check_code_for_cuda_kernel_launches(contents, filename)


def check_cuda_kernel_launches():
    """Checks all pytorch code for CUDA kernel launches without cuda error checks

    A summary and the list of offending files are printed to stderr.

    Returns:
        The number of unsafe kernel launches in the codebase
    """
    torch_dir = os.path.dirname(os.path.realpath(__file__))
    torch_dir = os.path.dirname(torch_dir)  # Go up to parent torch
    torch_dir = os.path.dirname(torch_dir)  # Go up to parent caffe2

    # `$BASE/build` and `$BASE/torch/include` are generated
    # so we don't want to flag their contents. Each path component is joined
    # separately so the comparison with os.walk's `root` also holds on
    # Windows, where the separator is `\` rather than `/`.
    generated_dirs = {
        os.path.join(torch_dir, "build"),
        os.path.join(torch_dir, "torch", "include"),
    }

    kernels_without_checks = 0
    files_without_checks = []
    for root, dirnames, filenames in os.walk(torch_dir):
        if root in generated_dirs:
            # Curtail search by modifying dirnames in place
            # Yes, this is the way to do this, see `help(os.walk)`
            dirnames[:] = []
            continue

        for x in filenames:
            filename = os.path.join(root, x)
            file_result = check_file(filename)
            if file_result > 0:
                kernels_without_checks += file_result
                files_without_checks.append(filename)

    if kernels_without_checks > 0:
        count_str = f"Found {kernels_without_checks} instances in " \
                    f"{len(files_without_checks)} files where kernel " \
                    "launches didn't have checks."
        print(count_str, file=sys.stderr)
        print("Files without checks:", file=sys.stderr)
        for x in files_without_checks:
            print(f"\t{x}", file=sys.stderr)
        # The summary is printed again after the file list, presumably so it
        # stays visible at the end of a long listing.
        print(count_str, file=sys.stderr)

    return kernels_without_checks


if __name__ == "__main__":
    unsafe_launches = check_cuda_kernel_launches()
    sys.exit(0 if unsafe_launches == 0 else 1)