xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/benchmarks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3import argparse
4import os
5import sys
6from typing import Set
7
8
# Note: hf and timm have their own version of this helper; torchbench does not.
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
def model_names(filename: str) -> Set[str]:
    """Return the set of model names listed in ``filename``.

    The first field of every non-blank line is taken as a model name. Fields
    may be separated by whitespace (spaces or tabs) or by commas, so lines such
    as ``"resnet50 128"``, ``"BertForMaskedLM,8"`` and ``"alexnet"`` each
    yield just the model name. Blank (or whitespace/comma-only) lines are
    skipped so they never contribute an empty-string entry, which could make
    the suite-disjointness checks fail spuriously.

    Args:
        filename: Path of the model-list text file to read.

    Returns:
        The set of model names found in the file.

    Raises:
        OSError: If the file cannot be opened.
    """
    names = set()
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(filename, encoding="utf-8") as fh:
        for line in fh:
            # Treat commas like whitespace, then split on any whitespace run;
            # this also drops leading indentation and the trailing newline.
            fields = line.replace(",", " ").split()
            if fields:
                names.add(fields[0])
    return names
23
24
# Directory containing this script; the per-suite model list files live here.
_BENCHMARKS_DIR = os.path.dirname(__file__)

# NOTE: despite the ``*_FILE_NAME`` suffix, the HF and torchbench constants
# hold sets of model names (the return value of ``model_names``), not paths.
TIMM_MODEL_NAMES = model_names(os.path.join(_BENCHMARKS_DIR, "timm_models_list.txt"))
HF_MODELS_FILE_NAME = model_names(
    os.path.join(_BENCHMARKS_DIR, "huggingface_models_list.txt")
)
TORCHBENCH_MODELS_FILE_NAME = model_names(
    os.path.join(_BENCHMARKS_DIR, "all_torchbench_models_list.txt")
)

# Each model name must belong to exactly one suite; otherwise ``--only`` would
# silently route an overlapping name to whichever suite it checks first.
# On failure, report the offending names so the list files can be fixed.
# timm <> HF disjoint
assert TIMM_MODEL_NAMES.isdisjoint(HF_MODELS_FILE_NAME), (
    "models listed in both timm and HF: "
    f"{sorted(TIMM_MODEL_NAMES & HF_MODELS_FILE_NAME)}"
)
# timm <> torchbench disjoint
assert TIMM_MODEL_NAMES.isdisjoint(TORCHBENCH_MODELS_FILE_NAME), (
    "models listed in both timm and torchbench: "
    f"{sorted(TIMM_MODEL_NAMES & TORCHBENCH_MODELS_FILE_NAME)}"
)
# torchbench <> HF disjoint
assert TORCHBENCH_MODELS_FILE_NAME.isdisjoint(HF_MODELS_FILE_NAME), (
    "models listed in both torchbench and HF: "
    f"{sorted(TORCHBENCH_MODELS_FILE_NAME & HF_MODELS_FILE_NAME)}"
)
41
42
def parse_args(args=None):
    """Parse the dispatcher's own command-line options.

    Only ``--only`` is defined here. Any other arguments are not rejected;
    they come back unparsed in the second element of the result (presumably
    so the suite runners can handle them — confirm against those runners).

    Args:
        args: Argument list to parse; ``None`` lets argparse use
            ``sys.argv[1:]``.

    Returns:
        A ``(namespace, remaining_args)`` pair, as produced by
        ``ArgumentParser.parse_known_args``.
    """
    only_help = """Run just one model from whichever model suite it belongs to. Or
        specify the path and class name of the model in format like:
        --only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME>

        Due to the fact that dynamo changes current working directory,
        the path should be an absolute path.

        The class should have a method get_example_inputs to return the inputs
        for the model. An example looks like
        ```
        class LinearModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 10)

            def forward(self, x):
                return self.linear(x)

            def get_example_inputs(self):
                return (torch.randn(2, 10),)
        ```
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--only", help=only_help)
    return cli.parse_known_args(args)
71
72
if __name__ == "__main__":
    args, unknown = parse_args()
    if not args.only:
        # No single model requested: run every suite in turn (torchbench,
        # then HF, then timm). Each runner module is imported lazily, right
        # before its entry point is called.
        import torchbench

        torchbench.torchbench_main()

        import huggingface

        huggingface.huggingface_main()

        import timm_models

        timm_models.timm_main()
    else:
        name = args.only
        # Guard clause: reject names that appear in none of the suite lists.
        # NOTE(review): the ``path:<file>,class:<cls>`` form documented in
        # parse_args is never in these lists, so it also lands here —
        # confirm whether it should instead be forwarded to a runner.
        if not (
            name in TIMM_MODEL_NAMES
            or name in HF_MODELS_FILE_NAME
            or name in TORCHBENCH_MODELS_FILE_NAME
        ):
            print(f"Illegal model name? {name}")
            sys.exit(-1)

        # Suites are checked in order: timm, HF, torchbench.
        if name in TIMM_MODEL_NAMES:
            import timm_models

            timm_models.timm_main()
        elif name in HF_MODELS_FILE_NAME:
            import huggingface

            huggingface.huggingface_main()
        else:
            import torchbench

            torchbench.torchbench_main()
104