xref: /aosp_15_r20/external/pytorch/docs/source/scripts/exportdb/generate_example_rst.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import inspect
2import os
3import re
4from pathlib import Path
5
6import torch
7import torch._dynamo as torchdynamo
8from torch._export.db.case import ExportCase
9from torch._export.db.examples import all_examples
10from torch.export import export
11
12
# Directory that contains this script.
PWD = Path(__file__).absolute().parent
# Three levels above PWD, i.e. four levels above this file.
ROOT = PWD.parent.parent.parent
# Documentation source tree and the directory the generated ExportDB pages go to.
SOURCE = ROOT / "source"
EXPORTDB_SOURCE = SOURCE / "generated" / "exportdb"
17
18
19def generate_example_rst(example_case: ExportCase):
20    """
21    Generates the .rst files for all the examples in db/examples/
22    """
23
24    model = example_case.model
25
26    tags = ", ".join(f":doc:`{tag} <{tag}>`" for tag in example_case.tags)
27
28    source_file = (
29        inspect.getfile(model.__class__)
30        if isinstance(model, torch.nn.Module)
31        else inspect.getfile(model)
32    )
33    with open(source_file) as file:
34        source_code = file.read()
35    source_code = source_code.replace("\n", "\n    ")
36    splitted_source_code = re.split(r"@export_rewrite_case.*\n", source_code)
37
38    assert len(splitted_source_code) in {
39        1,
40        2,
41    }, f"more than one @export_rewrite_case decorator in {source_code}"
42
43    more_arguments = ""
44    if example_case.example_kwargs:
45        more_arguments += ", example_kwargs"
46    if example_case.dynamic_shapes:
47        more_arguments += ", dynamic_shapes=dynamic_shapes"
48
49    # Generate contents of the .rst file
50    title = f"{example_case.name}"
51    doc_contents = f"""{title}
52{'^' * (len(title))}
53
54.. note::
55
56    Tags: {tags}
57
58    Support Level: {example_case.support_level.name}
59
60Original source code:
61
62.. code-block:: python
63
64    {splitted_source_code[0]}
65
66    torch.export.export(model, example_args{more_arguments})
67
68Result:
69
70.. code-block::
71
72"""
73
74    # Get resulting graph from dynamo trace
75    try:
76        exported_program = export(
77            model,
78            example_case.example_args,
79            example_case.example_kwargs,
80            dynamic_shapes=example_case.dynamic_shapes,
81        )
82        graph_output = str(exported_program)
83        graph_output = re.sub(r"        # File(.|\n)*?\n", "", graph_output)
84        graph_output = graph_output.replace("\n", "\n    ")
85        output = f"    {graph_output}"
86    except torchdynamo.exc.Unsupported as e:
87        output = "    Unsupported: " + str(e).split("\n")[0]
88    except AssertionError as e:
89        output = "    AssertionError: " + str(e).split("\n")[0]
90    except RuntimeError as e:
91        output = "    RuntimeError: " + str(e).split("\n")[0]
92
93    doc_contents += output + "\n"
94
95    if len(splitted_source_code) == 2:
96        doc_contents += f"""\n
97You can rewrite the example above to something like the following:
98
99.. code-block:: python
100
101{splitted_source_code[1]}
102
103"""
104
105    return doc_contents
106
107
def generate_index_rst(example_cases, tag_to_modules, support_level_to_modules):
    """
    Write the ExportDB ``index.rst`` into ``EXPORTDB_SOURCE``.

    The page contains the text of ``blurb.txt`` (read from this script's
    directory), a toctree with one entry per tag, and one section per support
    level. Each section inlines the .rst of every example at that level.

    Args:
        example_cases: Unused; kept so existing callers keep working.
        tag_to_modules: Mapping of tag name -> list of example .rst strings.
            Only the keys are used here, as toctree entries.
        support_level_to_modules: Mapping of support level -> list of example
            .rst strings. The key's ``name`` attribute becomes the section
            title (e.g. ``NOT_SUPPORTED_YET`` -> ``Not Supported Yet``).
    """

    sections = []
    for level, modules_rst in support_level_to_modules.items():
        support_level = level.name.lower().replace("_", " ").title()
        module_contents = "\n\n".join(modules_rst)
        sections.append(
            f"""
{support_level}
{'-' * (len(support_level))}

{module_contents}
"""
        )
    support_contents = "".join(sections)

    # Toctree entries must be indented like the directive's options below.
    tag_names = "\n    ".join(tag_to_modules)

    with open(os.path.join(PWD, "blurb.txt"), encoding="utf-8") as file:
        blurb = file.read()

    # Generate contents of the .rst file
    doc_contents = f""".. _torch.export_db:

ExportDB
========

{blurb}

.. toctree::
    :maxdepth: 1
    :caption: Tags

    {tag_names}

{support_contents}
"""

    # Write as UTF-8 so the output does not depend on the locale encoding.
    with open(
        os.path.join(EXPORTDB_SOURCE, "index.rst"), "w", encoding="utf-8"
    ) as f:
        f.write(doc_contents)
148
149
def generate_tag_rst(tag_to_modules, output_dir=None):
    """
    For each tag that shows up in each ExportCase.tag, write ``<tag>.rst``
    containing all the examples that have that tag.

    Args:
        tag_to_modules: Mapping of tag name -> list of example .rst strings.
        output_dir: Directory to write the pages into. Defaults to
            ``EXPORTDB_SOURCE``. The directory must already exist.
    """
    if output_dir is None:
        output_dir = EXPORTDB_SOURCE

    for tag, modules_rst in tag_to_modules.items():
        header = f"{tag}\n{'=' * (len(tag) + 4)}\n"
        body = "\n\n".join(modules_rst)
        # Replace every run of three or more '=' in the example content with
        # '-' of the same length. This presumably keeps '=' reserved for this
        # page's title underline.
        body = re.sub(r"={3,}", lambda match: "-" * len(match.group()), body)

        # Write as UTF-8 so the output does not depend on the locale encoding.
        with open(
            os.path.join(output_dir, f"{tag}.rst"), "w", encoding="utf-8"
        ) as f:
            f.write(header + body)
166
167
def generate_rst():
    """
    Regenerate every ExportDB .rst page under ``EXPORTDB_SOURCE``.

    Each example returned by ``all_examples()`` is rendered once. The rendered
    page is then grouped under each of its tags and under its support level.
    Those groupings drive the per-tag pages and ``index.rst``.
    """
    # exist_ok replaces the exists()-then-makedirs() check, which can race with
    # another process creating the same directory.
    os.makedirs(EXPORTDB_SOURCE, exist_ok=True)

    example_cases = all_examples()
    tag_to_modules = {}
    support_level_to_modules = {}
    for example_case in example_cases.values():
        doc_contents = generate_example_rst(example_case)

        for tag in example_case.tags:
            tag_to_modules.setdefault(tag, []).append(doc_contents)

        support_level_to_modules.setdefault(example_case.support_level, []).append(
            doc_contents
        )

    generate_tag_rst(tag_to_modules)
    generate_index_rst(example_cases, tag_to_modules, support_level_to_modules)
187
188
if __name__ == "__main__":
    # Regenerate all ExportDB pages when this file is executed directly.
    generate_rst()
191