import inspect
import os
import re
from pathlib import Path

import torch
import torch._dynamo as torchdynamo
from torch._export.db.case import ExportCase
from torch._export.db.examples import all_examples
from torch.export import export


PWD = Path(__file__).absolute().parent
ROOT = Path(__file__).absolute().parent.parent.parent.parent
SOURCE = ROOT / Path("source")
EXPORTDB_SOURCE = SOURCE / Path("generated") / Path("exportdb")


def _read_example_source(model):
    """
    Read the source file that defines ``model`` and indent it for RST.

    ``model`` may be a ``torch.nn.Module`` instance (the file of its class is
    read) or any other object ``inspect.getfile`` accepts (e.g. a function).

    Returns a ``(source_file, source_code)`` tuple where every newline in
    ``source_code`` is followed by four spaces, so the text can be placed
    directly inside an indented ``.. code-block::`` directive.
    """
    source_file = (
        inspect.getfile(model.__class__)
        if isinstance(model, torch.nn.Module)
        else inspect.getfile(model)
    )
    # Python source is UTF-8 by default; don't depend on the platform locale.
    with open(source_file, encoding="utf-8") as file:
        source_code = file.read()
    return source_file, source_code.replace("\n", "\n    ")


def _render_export_result(example_case: ExportCase):
    """
    Run torch.export on an example and return the indented RST result text.

    On success this is ``str(ExportedProgram)`` with the ``# File ...``
    source-location comments removed. If exporting (or printing the program)
    raises ``torch._dynamo.exc.Unsupported``, ``AssertionError`` or
    ``RuntimeError``, the exception kind plus the first line of its message
    is returned instead. Any other exception propagates.
    """
    try:
        exported_program = export(
            example_case.model,
            example_case.example_args,
            example_case.example_kwargs,
            dynamic_shapes=example_case.dynamic_shapes,
        )
        graph_output = str(exported_program)
        # Strip source-location comments: they embed local file paths and
        # make the generated docs noisy and machine-dependent.
        graph_output = re.sub(r" # File.*\n", "", graph_output)
        graph_output = graph_output.replace("\n", "\n    ")
        return f"    {graph_output}"
    # Unsupported must come before RuntimeError: it is a RuntimeError subclass.
    except torchdynamo.exc.Unsupported as e:
        return "    Unsupported: " + str(e).split("\n")[0]
    except AssertionError as e:
        return "    AssertionError: " + str(e).split("\n")[0]
    except RuntimeError as e:
        return "    RuntimeError: " + str(e).split("\n")[0]


def generate_example_rst(example_case: ExportCase):
    """
    Generate the .rst contents for a single example from db/examples/.

    The page lists the example's tags and support level, shows the source of
    the file that defines the example, and shows the result of exporting it.
    If that file contains an ``@export_rewrite_case`` decorator, the code
    after the decorator is shown as a suggested rewrite.

    Returns the .rst text; nothing is written to disk here.

    Raises:
        ValueError: if the source file contains more than one
            ``@export_rewrite_case`` decorator.
    """
    tags = ", ".join(f":doc:`{tag} <{tag}>`" for tag in example_case.tags)

    source_file, source_code = _read_example_source(example_case.model)
    # Everything before the decorator is the original example; everything
    # after it is the suggested rewrite.
    source_parts = re.split(r"@export_rewrite_case.*\n", source_code)
    if len(source_parts) > 2:
        raise ValueError(
            f"more than one @export_rewrite_case decorator in {source_file}"
        )

    more_arguments = ""
    if example_case.example_kwargs:
        more_arguments += ", example_kwargs"
    if example_case.dynamic_shapes:
        more_arguments += ", dynamic_shapes=dynamic_shapes"

    # Generate contents of the .rst file
    title = f"{example_case.name}"
    doc_contents = f"""{title}
{'^' * (len(title))}

.. note::

    Tags: {tags}

    Support Level: {example_case.support_level.name}

Original source code:

.. code-block:: python

    {source_parts[0]}

    torch.export.export(model, example_args{more_arguments})

Result:

.. code-block::

"""

    doc_contents += _render_export_result(example_case) + "\n"

    if len(source_parts) == 2:
        doc_contents += f"""\n
You can rewrite the example above to something like the following:

.. code-block:: python

{source_parts[1]}

"""

    return doc_contents


def generate_index_rst(example_cases, tag_to_modules, support_level_to_modules):
    """
    Generate and write ``index.rst`` for ExportDB into EXPORTDB_SOURCE.

    Args:
        example_cases: unused; kept so existing callers keep working.
        tag_to_modules: mapping from tag name to the rendered example pages
            carrying that tag; its keys populate the "Tags" toctree.
        support_level_to_modules: mapping from support level (an enum member;
            its ``name`` is used as the section title) to the rendered
            example pages at that level.

    The blurb at the top of the page is read from ``blurb.txt`` next to this
    script.
    """
    support_contents = "".join(
        f"""
{title}
{'-' * (len(title))}

{module_contents}
"""
        for title, module_contents in (
            (level.name.lower().replace("_", " ").title(), "\n\n".join(pages))
            for level, pages in support_level_to_modules.items()
        )
    )

    tag_names = "\n    ".join(tag_to_modules)

    with open(PWD / "blurb.txt", encoding="utf-8") as file:
        blurb = file.read()

    # Generate contents of the .rst file
    doc_contents = f""".. _torch.export_db:

ExportDB
========

{blurb}

.. toctree::
    :maxdepth: 1
    :caption: Tags

    {tag_names}

{support_contents}
"""

    with open(EXPORTDB_SOURCE / "index.rst", "w", encoding="utf-8") as f:
        f.write(doc_contents)


def generate_tag_rst(tag_to_modules):
    """
    For each tag that shows up in any ExportCase.tags, write ``<tag>.rst``
    into EXPORTDB_SOURCE containing all the examples that have that tag.
    """
    for tag, modules_rst in tag_to_modules.items():
        doc_contents = f"{tag}\n{'=' * (len(tag) + 4)}\n"
        full_modules_rst = "\n\n".join(modules_rst)
        # Rewrite any run of 3+ '=' in the embedded pages as '-' of the same
        # length, so only the tag title above uses '=' underlining.
        full_modules_rst = re.sub(
            r"={3,}", lambda match: "-" * len(match.group()), full_modules_rst
        )
        doc_contents += full_modules_rst

        with open(EXPORTDB_SOURCE / f"{tag}.rst", "w", encoding="utf-8") as f:
            f.write(doc_contents)


def generate_rst():
    """
    Render every registered export example and write the ExportDB pages.

    Writes one ``<tag>.rst`` per tag plus ``index.rst`` into EXPORTDB_SOURCE,
    creating the directory if needed.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(EXPORTDB_SOURCE, exist_ok=True)

    example_cases = all_examples()
    tag_to_modules = {}
    support_level_to_modules = {}
    for example_case in example_cases.values():
        doc_contents = generate_example_rst(example_case)

        for tag in example_case.tags:
            tag_to_modules.setdefault(tag, []).append(doc_contents)

        support_level_to_modules.setdefault(example_case.support_level, []).append(
            doc_contents
        )

    generate_tag_rst(tag_to_modules)
    generate_index_rst(example_cases, tag_to_modules, support_level_to_modules)


if __name__ == "__main__":
    generate_rst()