# mypy: allow-untyped-defs
import os
import pathlib
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple, Union


def materialize_lines(lines: List[str], indentation: int) -> str:
    """
    Join ``lines`` into one string, indenting every line after the first.

    Embedded newlines inside an element are indented as well, so a multi-line
    element stays aligned with the rest of the block.

    Args:
        lines: Text fragments to join; each may contain newlines.
        indentation: Number of spaces placed after every newline.

    Returns:
        The joined text (``""`` for an empty ``lines``).
    """
    new_line_with_indent = "\n" + " " * indentation
    return new_line_with_indent.join(
        line.replace("\n", new_line_with_indent) for line in lines
    )


def gen_from_template(
    dir: str,
    template_name: str,
    output_name: str,
    replacements: List[Tuple[str, Any, int]],
):
    """
    Render ``dir/template_name`` into ``dir/output_name``.

    Each replacement is ``(placeholder, lines, indentation)``. Every occurrence
    of ``placeholder`` in the template is replaced by ``lines`` joined with
    :func:`materialize_lines`. The output file is written exactly once, after
    all replacements have been applied, so a failure part-way through leaves
    any existing output file untouched.
    """
    template_path = os.path.join(dir, template_name)
    output_path = os.path.join(dir, output_name)

    with open(template_path) as f:
        content = f.read()
    for placeholder, lines, indentation in replacements:
        content = content.replace(placeholder, materialize_lines(lines, indentation))
    with open(output_path, "w") as f:
        f.write(content)


def find_file_paths(dir_paths: List[str], files_to_exclude: Set[str]) -> Set[str]:
    """
    When given a path to a directory, returns the paths to the relevant files within it.

    Only ``.py`` files whose bare file name is not in ``files_to_exclude`` are
    returned. This function does NOT recursively traverse subdirectories.
    """
    paths: Set[str] = set()
    for dir_path in dir_paths:
        paths.update(
            os.path.join(dir_path, fname)
            for fname in os.listdir(dir_path)
            if fname.endswith(".py") and fname not in files_to_exclude
        )
    return paths


def extract_method_name(line: str) -> str:
    """
    Extract method name from decorator in the form of "@functional_datapipe({method_name})".

    Both double- and single-quoted names are accepted.

    Raises:
        RuntimeError: If the line contains no quoted argument.
    """
    if '("' in line:
        start_token, end_token = '("', '")'
    elif "('" in line:
        start_token, end_token = "('", "')"
    else:
        raise RuntimeError(
            f"Unable to find appropriate method name within line:\n{line}"
        )
    start, end = line.find(start_token) + len(start_token), line.find(end_token)
    return line[start:end]


def extract_class_name(line: str) -> str:
    """
    Extract class name from class definition in the form of "class {CLASS_NAME}({Type}):".

    A definition without base classes ("class {CLASS_NAME}:") is also handled;
    surrounding whitespace is removed from the result.
    """
    start_token = "class "
    start = line.find(start_token) + len(start_token)
    rest = line[start:]
    # The name ends at the base-class list if present, otherwise at the colon.
    for end_token in ("(", ":"):
        end = rest.find(end_token)
        if end != -1:
            return rest[:end].strip()
    return rest.strip()


def parse_datapipe_file(
    file_path: str,
) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
    """
    Parse one source file for ``@functional_datapipe``-decorated classes.

    Returns:
        A 4-tuple of:
        - method name -> cleaned constructor signature (see :func:`process_signature`),
        - method name -> class name,
        - set of method names whose constructor is ``__new__`` rather than ``__init__``,
        - method name -> raw docstring lines (lines inside or containing ``\"\"\"``).
    """
    method_to_signature: Dict[str, str] = {}
    method_to_class_name: Dict[str, str] = {}
    special_output_type: Set[str] = set()
    doc_string_dict: Dict[str, List[str]] = defaultdict(list)
    with open(file_path) as f:
        open_paren_count = 0
        method_name, class_name, signature = "", "", ""
        skip = False
        for line in f:
            # An odd number of triple quotes opens or closes a docstring.
            if line.count('"""') % 2 == 1:
                skip = not skip
            if skip or '"""' in line:  # Saving docstrings
                doc_string_dict[method_name].append(line)
                continue
            if "@functional_datapipe" in line:
                method_name = extract_method_name(line)
                doc_string_dict[method_name] = []
                continue
            if method_name and "class " in line:
                class_name = extract_class_name(line)
                continue
            if method_name and ("def __init__(" in line or "def __new__(" in line):
                if "def __new__(" in line:
                    special_output_type.add(method_name)
                open_paren_count += 1
                # Keep only the text after the opening parenthesis of the def.
                start = line.find("(") + len("(")
                line = line[start:]
            if open_paren_count > 0:
                # Track nesting so multi-line signatures are collected until the
                # parenthesis that closes the def is reached.
                open_paren_count += line.count("(")
                open_paren_count -= line.count(")")
                if open_paren_count == 0:
                    end = line.rfind(")")
                    signature += line[:end]
                    method_to_signature[method_name] = process_signature(signature)
                    method_to_class_name[method_name] = class_name
                    method_name, class_name, signature = "", "", ""
                elif open_paren_count < 0:
                    raise RuntimeError(
                        "open parenthesis count < 0. This shouldn't be possible."
                    )
                else:
                    signature += line.strip("\n").strip(" ")
    return (
        method_to_signature,
        method_to_class_name,
        special_output_type,
        doc_string_dict,
    )


def parse_datapipe_files(
    file_paths: Set[str],
) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
    """
    Parse every file in ``file_paths`` and merge the per-file results.

    Files are processed in sorted order so that, if two files define the same
    method name, the result does not depend on set iteration order (the
    later file in sorted order wins).

    Returns:
        The same 4-tuple shape as :func:`parse_datapipe_file`, merged over all files.
    """
    methods_and_signatures: Dict[str, str] = {}
    methods_and_class_names: Dict[str, str] = {}
    methods_with_special_output_types: Set[str] = set()
    methods_and_doc_strings: Dict[str, List[str]] = {}
    for path in sorted(file_paths):
        (
            method_to_signature,
            method_to_class_name,
            methods_needing_special_output_types,
            doc_string_dict,
        ) = parse_datapipe_file(path)
        methods_and_signatures.update(method_to_signature)
        methods_and_class_names.update(method_to_class_name)
        methods_with_special_output_types.update(methods_needing_special_output_types)
        methods_and_doc_strings.update(doc_string_dict)
    return (
        methods_and_signatures,
        methods_and_class_names,
        methods_with_special_output_types,
        methods_and_doc_strings,
    )


def split_outside_bracket(line: str, delimiter: str = ",") -> List[str]:
    """
    Given a line of text, split it on comma unless the comma is within a bracket '[]'.

    Only square brackets are tracked; delimiters inside parentheses or braces
    still split. Tokens keep their surrounding whitespace.
    """
    bracket_count = 0
    curr_token = ""
    res = []
    for char in line:
        if char == "[":
            bracket_count += 1
        elif char == "]":
            bracket_count -= 1
        elif char == delimiter and bracket_count == 0:
            res.append(curr_token)
            curr_token = ""
            continue
        curr_token += char
    res.append(curr_token)
    return res


def process_signature(line: str) -> str:
    """
    Clean up a given raw function signature.

    This includes removing the self-referential datapipe argument (the argument
    right after ``self``/``cls`` unless it starts with ``*``), replacing
    default values of ``Callable`` arguments with ``...``, renaming ``cls`` to
    ``self``, and dropping empty tokens, newlines, and surrounding spaces.
    """
    tokens: List[str] = split_outside_bracket(line)
    for i, token in enumerate(tokens):
        stripped = token.strip(" ")
        tokens[i] = stripped
        if stripped == "cls":
            tokens[i] = "self"
        elif i > 0 and tokens[i - 1] == "self" and not stripped.startswith("*"):
            # Remove the datapipe after 'self' or 'cls' unless it has '*'.
            # startswith() also copes with an empty token (e.g. trailing comma).
            tokens[i] = ""
        elif "Callable =" in stripped:  # Remove default argument if it is a function
            # Split on the FIRST '=' only: the default value itself may contain
            # '=' (e.g. ``lambda x=1: x``).
            head = stripped.split("=", 1)[0]
            tokens[i] = head.strip(" ") + "= ..."
    return ", ".join(t for t in tokens if t != "")


def get_method_definitions(
    file_path: Union[str, List[str]],
    files_to_exclude: Set[str],
    deprecated_files: Set[str],
    default_output_type: str,
    method_to_special_output_type: Dict[str, str],
    root: str = "",
) -> List[str]:
    """
    Generate ``.pyi`` stubs for the functional forms of DataPipes.

    Steps:
    1. Find the ``.py`` files to process (excluding ``files_to_exclude`` and
       ``deprecated_files``).
    2. Parse method names, class names, signatures, and docstrings.
    3. Remove the first argument after self (unless it starts with "*"),
       Callable default arguments, newlines, and spaces.

    Args:
        file_path: Directory (or directories) relative to ``root``.
        files_to_exclude: Bare file names to skip.
        deprecated_files: Additional bare file names to skip.
        default_output_type: Return annotation used for ordinary methods.
        method_to_special_output_type: Return annotation overrides per method.
        root: Base directory; defaults to this file's directory when empty.

    Returns:
        Stub strings, sorted by their ``def`` line.

    Raises:
        KeyError: If a ``__new__``-based DataPipe has no entry in
            ``method_to_special_output_type``.
    """
    if root == "":
        root = str(pathlib.Path(__file__).parent.resolve())
    file_path = [file_path] if isinstance(file_path, str) else file_path
    file_path = [os.path.join(root, path) for path in file_path]
    file_paths = find_file_paths(
        file_path, files_to_exclude=files_to_exclude.union(deprecated_files)
    )
    (
        methods_and_signatures,
        methods_and_class_names,
        methods_w_special_output_types,
        methods_and_doc_strings,
    ) = parse_datapipe_files(file_paths)

    # Caller-listed overrides count as special even if not detected via __new__.
    methods_w_special_output_types.update(method_to_special_output_type.keys())

    method_definitions = []
    for method_name, arguments in methods_and_signatures.items():
        class_name = methods_and_class_names[method_name]
        if method_name in methods_w_special_output_types:
            output_type = method_to_special_output_type[method_name]
        else:
            output_type = default_output_type
        doc_string = "".join(methods_and_doc_strings[method_name])
        if doc_string == "":
            doc_string = " ...\n"
        method_definitions.append(
            f"# Functional form of '{class_name}'\n"
            f"def {method_name}({arguments}) -> {output_type}:\n"
            f"{doc_string}"
        )
    method_definitions.sort(
        key=lambda s: s.split("\n")[1]
    )  # sorting based on method_name

    return method_definitions


# Defined outside of main() so they can be imported by TorchData
iterDP_file_path: str = "iter"
iterDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
iterDP_deprecated_files: Set[str] = set()
iterDP_method_to_special_output_type: Dict[str, str] = {
    "demux": "List[IterDataPipe]",
    "fork": "List[IterDataPipe]",
}

mapDP_file_path: str = "map"
mapDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
mapDP_deprecated_files: Set[str] = set()
mapDP_method_to_special_output_type: Dict[str, str] = {"shuffle": "IterDataPipe"}


def main() -> None:
    """
    Generate ``datapipe.pyi`` by filling the ``datapipe.pyi.in`` template.

    The functional-form stubs for IterDataPipe and MapDataPipe are collected
    from the ``iter`` and ``map`` directories next to this file and injected,
    indented by four spaces, at the ``${IterDataPipeMethods}`` and
    ``${MapDataPipeMethods}`` placeholders.

    TODO: The current implementation of this script only generates interfaces for built-in methods. To generate
    interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`.
    """
    iter_stubs = get_method_definitions(
        iterDP_file_path,
        iterDP_files_to_exclude,
        iterDP_deprecated_files,
        "IterDataPipe",
        iterDP_method_to_special_output_type,
    )
    map_stubs = get_method_definitions(
        mapDP_file_path,
        mapDP_files_to_exclude,
        mapDP_deprecated_files,
        "MapDataPipe",
        mapDP_method_to_special_output_type,
    )

    stub_indent = 4
    here = str(pathlib.Path(__file__).parent.resolve())
    gen_from_template(
        dir=here,
        template_name="datapipe.pyi.in",
        output_name="datapipe.pyi",
        replacements=[
            ("${IterDataPipeMethods}", iter_stubs, stub_indent),
            ("${MapDataPipeMethods}", map_stubs, stub_indent),
        ],
    )


if __name__ == "__main__":
    main()