xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/gen_pyi.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import os
3import pathlib
4from collections import defaultdict
5from typing import Any, Dict, List, Set, Tuple, Union
6
7
def materialize_lines(lines: List[str], indentation: int) -> str:
    """
    Join ``lines`` into a single string, indenting every line after the first.

    Every newline in the result, whether it separates two entries of ``lines``
    or was already embedded inside an entry, is followed by ``indentation``
    spaces. The very first line receives no leading indent.
    """
    separator = "\n" + " " * indentation
    return separator.join(entry.replace("\n", separator) for entry in lines)
16
17
def gen_from_template(
    dir: str,
    template_name: str,
    output_name: str,
    replacements: List[Tuple[str, Any, int]],
):
    """
    Render a template file by substituting placeholders and write the result.

    Args:
        dir: Directory that holds the template and receives the output file.
        template_name: File name of the template inside ``dir``.
        output_name: File name of the generated file inside ``dir``.
        replacements: ``(placeholder, lines, indentation)`` triples. Every
            occurrence of ``placeholder`` in the template text is replaced by
            ``materialize_lines(lines, indentation)``.

    The output file is written exactly once, after all replacements have been
    applied, and it is written even when ``replacements`` is empty.
    """
    template_path = os.path.join(dir, template_name)
    output_path = os.path.join(dir, output_name)

    with open(template_path) as f:
        content = f.read()

    # Apply all substitutions in memory first; rewriting the file once per
    # placeholder is wasted I/O and leaves no output at all for an empty list.
    for placeholder, lines, indentation in replacements:
        content = content.replace(placeholder, materialize_lines(lines, indentation))

    with open(output_path, "w") as f:
        f.write(content)
35
36
def find_file_paths(dir_paths: List[str], files_to_exclude: Set[str]) -> Set[str]:
    """
    Collect the paths of the ``.py`` files located directly inside each directory.

    Entries whose file name appears in ``files_to_exclude`` are left out.
    This function does NOT recursively traverse subdirectories.
    """
    found: Set[str] = set()
    for directory in dir_paths:
        for name in os.listdir(directory):
            if name.endswith(".py") and name not in files_to_exclude:
                found.add(os.path.join(directory, name))
    return found
52
53
def extract_method_name(line: str) -> str:
    """
    Return the method name found in a "@functional_datapipe({method_name})" decorator line.

    The name may be wrapped in double or single quotes; double quotes are
    looked for first. Raises ``RuntimeError`` when neither form is present.
    """
    for opening, closing in (('("', '")'), ("('", "')")):
        if opening in line:
            begin = line.find(opening) + len(opening)
            finish = line.find(closing)
            return line[begin:finish]
    raise RuntimeError(
        f"Unable to find appropriate method name within line:\n{line}"
    )
66
67
def extract_class_name(line: str) -> str:
    """
    Extract the class name from a class definition line.

    Handles both "class {CLASS_NAME}({Type}):" and the base-less form
    "class {CLASS_NAME}:". The name is whatever sits between the ``class``
    keyword and the first "(" or ":" that follows it, with surrounding
    whitespace removed.
    """
    start_token = "class "
    start = line.find(start_token) + len(start_token)
    # Only look for terminators after the keyword, and fall back to ":" so a
    # definition without a base list does not yield a name with a trailing colon.
    terminators = [
        pos for pos in (line.find("(", start), line.find(":", start)) if pos != -1
    ]
    end = min(terminators) if terminators else len(line)
    return line[start:end].strip()
74
75
def parse_datapipe_file(
    file_path: str,
) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
    """
    Scan one source file line by line for ``@functional_datapipe`` classes.

    Returns a 4-tuple of:
        1. method name -> processed ``__init__``/``__new__`` signature,
        2. method name -> name of the decorated class,
        3. set of method names whose signature was read from ``__new__``,
        4. method name -> docstring lines collected while that method name was
           current (docstring lines seen while no method name is active are
           stored under the empty-string key).

    Raises:
        RuntimeError: if closing parentheses outnumber opening ones while a
            signature is being read, or if ``extract_method_name`` cannot find
            a quoted name on a decorator line.
    """
    method_to_signature, method_to_class_name, special_output_type = {}, {}, set()
    doc_string_dict = defaultdict(list)
    with open(file_path) as f:
        # Number of currently unclosed "(" of the signature being read
        open_paren_count = 0
        # State for the datapipe currently being parsed; "" means "none active"
        method_name, class_name, signature = "", "", ""
        # True while the current line is inside a multi-line triple-quoted string
        skip = False
        for line in f:
            # An odd number of triple quotes on a line opens or closes a docstring
            if line.count('"""') % 2 == 1:
                skip = not skip
            # Save docstring lines, including the line that closes the docstring
            # and one-line docstrings (both contain '"""' while skip is False)
            if skip or '"""' in line:  # Saving docstrings
                doc_string_dict[method_name].append(line)
                continue
            # A decorator starts a new datapipe; reset its collected docstring
            if "@functional_datapipe" in line:
                method_name = extract_method_name(line)
                doc_string_dict[method_name] = []
                continue
            # Only lines seen after a decorator are considered for the class name
            if method_name and "class " in line:
                class_name = extract_class_name(line)
                continue
            if method_name and ("def __init__(" in line or "def __new__(" in line):
                if "def __new__(" in line:
                    special_output_type.add(method_name)
                # Count the "(" that opens the parameter list, then drop
                # everything up to and including it from the line
                open_paren_count += 1
                start = line.find("(") + len("(")
                line = line[start:]
            if open_paren_count > 0:
                # Track parenthesis balance to find where the parameter list ends
                open_paren_count += line.count("(")
                open_paren_count -= line.count(")")
                if open_paren_count == 0:
                    # Parameter list closed on this line: keep the text before
                    # the last ")" and record the finished signature
                    end = line.rfind(")")
                    signature += line[:end]
                    method_to_signature[method_name] = process_signature(signature)
                    method_to_class_name[method_name] = class_name
                    # Reset state until the next decorator is encountered
                    method_name, class_name, signature = "", "", ""
                elif open_paren_count < 0:
                    raise RuntimeError(
                        "open parenthesis count < 0. This shouldn't be possible."
                    )
                else:
                    # Signature continues on the next line; accumulate this
                    # line without its newline and surrounding spaces
                    signature += line.strip("\n").strip(" ")
    return (
        method_to_signature,
        method_to_class_name,
        special_output_type,
        doc_string_dict,
    )
126
127
def parse_datapipe_files(
    file_paths: Set[str],
) -> Tuple[Dict[str, str], Dict[str, str], Set[str], Dict[str, List[str]]]:
    """
    Parse every file in ``file_paths`` and merge the per-file results.

    Returns the same 4-tuple shape as ``parse_datapipe_file``: signatures by
    method name, class names by method name, the set of methods with special
    output types, and docstring lines by method name. When two files define
    the same method name, the entry from the file processed later wins.
    """
    signatures: Dict[str, str] = {}
    class_names: Dict[str, str] = {}
    special_methods: Set[str] = set()
    doc_strings: Dict[str, List[str]] = {}
    for path in file_paths:
        file_signatures, file_class_names, file_special, file_docs = (
            parse_datapipe_file(path)
        )
        signatures.update(file_signatures)
        class_names.update(file_class_names)
        special_methods.update(file_special)
        doc_strings.update(file_docs)
    return signatures, class_names, special_methods, doc_strings
154
155
def split_outside_bracket(line: str, delimiter: str = ",") -> List[str]:
    """
    Split ``line`` on ``delimiter`` (a comma by default), ignoring delimiters inside '[]'.

    Only square brackets are tracked. The pieces are returned verbatim, with
    no whitespace stripping, and there is always at least one piece.
    """
    depth = 0
    pieces: List[str] = []
    buffer: List[str] = []
    for ch in line:
        if ch == "[":
            depth += 1
        elif ch == "]":
            depth -= 1
        elif ch == delimiter and depth == 0:
            # Top-level delimiter: close the current piece and drop the delimiter
            pieces.append("".join(buffer))
            buffer = []
            continue
        buffer.append(ch)
    pieces.append("".join(buffer))
    return pieces
173
174
def process_signature(line: str) -> str:
    """
    Clean up a given raw function signature.

    The signature is split on top-level commas and each parameter is stripped
    of surrounding spaces. Then:

    * a ``cls`` parameter is renamed to ``self``;
    * the parameter directly after ``self`` (the source datapipe) is removed,
      unless it starts with ``*``;
    * a parameter containing ``Callable =`` has its default replaced by ``...``.

    Empty parameters are dropped and the rest are joined with ", ".
    """
    tokens: List[str] = split_outside_bracket(line)
    for i, raw_token in enumerate(tokens):
        # Work on the stripped form everywhere so stray spaces cannot change
        # which branch is taken.
        token = raw_token.strip(" ")
        tokens[i] = token
        if token == "cls":
            tokens[i] = "self"
        elif i > 0 and tokens[i - 1] == "self" and not token.startswith("*"):
            # Remove the datapipe after 'self' or 'cls' unless it has '*'.
            # startswith() also copes with an empty token (e.g. "self,").
            tokens[i] = ""
        elif "Callable =" in token:  # Remove default argument if it is a function
            # Split on the first "=" only: the default expression may itself
            # contain "=" (e.g. a lambda using "==").
            head, _, _default = token.partition("=")
            tokens[i] = head.strip(" ") + "= ..."
    return ", ".join(t for t in tokens if t != "")
196
197
def get_method_definitions(
    file_path: Union[str, List[str]],
    files_to_exclude: Set[str],
    deprecated_files: Set[str],
    default_output_type: str,
    method_to_special_output_type: Dict[str, str],
    root: str = "",
) -> List[str]:
    """
    Build the .pyi stub text for the functional form of each DataPipe.

    Process:
      1. Find the files to process under ``root`` (defaults to the directory of
         this script), skipping excluded and deprecated file names.
      2. Parse method names, class names, signatures and docstrings.
      3. Render one stub per method: a comment naming the class, the ``def``
         line with its output type, and the docstring (or ``...`` if none).

    The stubs are returned sorted by their ``def`` line.
    """
    if root == "":
        root = str(pathlib.Path(__file__).parent.resolve())
    relative_dirs = [file_path] if isinstance(file_path, str) else file_path
    search_dirs = [os.path.join(root, rel_dir) for rel_dir in relative_dirs]
    skipped_files = files_to_exclude.union(deprecated_files)
    signatures, class_names, special_methods, doc_strings = parse_datapipe_files(
        find_file_paths(search_dirs, files_to_exclude=skipped_files)
    )

    # Every method with an explicitly configured output type counts as special
    special_methods.update(method_to_special_output_type)

    stubs = []
    for method_name, arguments in signatures.items():
        class_name = class_names[method_name]
        output_type = (
            method_to_special_output_type[method_name]
            if method_name in special_methods
            else default_output_type
        )
        # Fall back to an ellipsis body when no docstring was collected
        stub_body = "".join(doc_strings[method_name]) or "    ...\n"
        stubs.append(
            f"# Functional form of '{class_name}'\n"
            f"def {method_name}({arguments}) -> {output_type}:\n"
            f"{stub_body}"
        )

    # The second line of each stub is its "def ..." line, so this orders by method name
    return sorted(stubs, key=lambda stub: stub.split("\n")[1])
251
252
# Defined outside of main() so they can be imported by TorchData
# Configuration passed by main() to get_method_definitions() for IterDataPipe:
# sub-directory to scan (joined onto the root directory)
iterDP_file_path: str = "iter"
# file names in that directory that are not parsed
iterDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
# additional file names to skip (none at present)
iterDP_deprecated_files: Set[str] = set()
# method name -> return annotation used instead of the default output type
iterDP_method_to_special_output_type: Dict[str, str] = {
    "demux": "List[IterDataPipe]",
    "fork": "List[IterDataPipe]",
}

# Same configuration for MapDataPipe
mapDP_file_path: str = "map"
mapDP_files_to_exclude: Set[str] = {"__init__.py", "utils.py"}
mapDP_deprecated_files: Set[str] = set()
mapDP_method_to_special_output_type: Dict[str, str] = {"shuffle": "IterDataPipe"}
266
267
def main() -> None:
    """
    Generate ``datapipe.pyi`` from the template ``datapipe.pyi.in``.

    The stubs for IterDataPipe and MapDataPipe methods are substituted for the
    ``${IterDataPipeMethods}`` and ``${MapDataPipeMethods}`` placeholders, each
    with an indentation of four spaces. Both files live next to this script.

    TODO: The current implementation of this script only generates interfaces for built-in methods. To generate
          interface for user-defined DataPipes, consider changing `IterDataPipe.register_datapipe_as_function`.
    """
    iter_stubs = get_method_definitions(
        file_path=iterDP_file_path,
        files_to_exclude=iterDP_files_to_exclude,
        deprecated_files=iterDP_deprecated_files,
        default_output_type="IterDataPipe",
        method_to_special_output_type=iterDP_method_to_special_output_type,
    )
    map_stubs = get_method_definitions(
        file_path=mapDP_file_path,
        files_to_exclude=mapDP_files_to_exclude,
        deprecated_files=mapDP_deprecated_files,
        default_output_type="MapDataPipe",
        method_to_special_output_type=mapDP_method_to_special_output_type,
    )

    script_dir = pathlib.Path(__file__).parent.resolve()
    gen_from_template(
        dir=str(script_dir),
        template_name="datapipe.pyi.in",
        output_name="datapipe.pyi",
        replacements=[
            ("${IterDataPipeMethods}", iter_stubs, 4),
            ("${MapDataPipeMethods}", map_stubs, 4),
        ],
    )
302
303
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
306