#!/usr/bin/env python3

import argparse
import os
import sys
from typing import Set


# Note - hf and timm have their own version of this, torchbench does not
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
def model_names(filename: str) -> Set[str]:
    """Return the set of model names listed in ``filename``.

    The model name is the first field of each non-blank line. Fields are
    split on a single space; a line that contains no space is split on
    commas instead, so both ``name extra`` and ``name,extra`` lines work.
    Blank and whitespace-only lines are ignored.
    """
    names = set()
    with open(filename, encoding="utf-8") as fh:
        for raw_line in fh:
            # strip() (not just rstrip()) so leading whitespace cannot turn the
            # first field into "" either.
            line = raw_line.strip()
            if not line:
                # A blank line would otherwise contribute "" to every suite's
                # set and make the disjointness checks below fail spuriously.
                continue
            line_parts = line.split(" ")
            if len(line_parts) == 1:
                line_parts = line.split(",")
            names.add(line_parts[0])
    return names


# NOTE: despite the *_FILE_NAME suffix, these hold sets of model names read
# from the list files, not file names.
TIMM_MODEL_NAMES = model_names(
    os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
)
HF_MODELS_FILE_NAME = model_names(
    os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt")
)
TORCHBENCH_MODELS_FILE_NAME = model_names(
    os.path.join(os.path.dirname(__file__), "all_torchbench_models_list.txt")
)

# The suites must not share model names: --only dispatch below checks them in
# a fixed order, so a shared name would silently go to the first match.
# timm <> HF disjoint
assert TIMM_MODEL_NAMES.isdisjoint(HF_MODELS_FILE_NAME), sorted(
    TIMM_MODEL_NAMES & HF_MODELS_FILE_NAME
)
# timm <> torch disjoint
assert TIMM_MODEL_NAMES.isdisjoint(TORCHBENCH_MODELS_FILE_NAME), sorted(
    TIMM_MODEL_NAMES & TORCHBENCH_MODELS_FILE_NAME
)
# torch <> hf disjoint
assert TORCHBENCH_MODELS_FILE_NAME.isdisjoint(HF_MODELS_FILE_NAME), sorted(
    TORCHBENCH_MODELS_FILE_NAME & HF_MODELS_FILE_NAME
)


def parse_args(args=None):
    """Parse the options this shim understands.

    Only ``--only`` is defined here. ``parse_known_args`` is used so every
    other argument is tolerated and returned in the second element of the
    result instead of raising.

    Args:
        args: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        A ``(namespace, unknown_args)`` tuple from ``parse_known_args``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--only",
        help="""Run just one model from whichever model suite it belongs to. Or
        specify the path and class name of the model in format like:
        --only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME>

        Due to the fact that dynamo changes current working directory,
        the path should be an absolute path.

        The class should have a method get_example_inputs to return the inputs
        for the model. An example looks like
        ```
        class LinearModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 10)

            def forward(self, x):
                return self.linear(x)

            def get_example_inputs(self):
                return (torch.randn(2, 10),)
        ```
        """,
    )
    return parser.parse_known_args(args)


if __name__ == "__main__":
    # Unrecognized arguments are tolerated here; each suite's *_main() is
    # called with no arguments (presumably it re-parses sys.argv — confirm).
    args, _unknown = parse_args()
    if args.only:
        name = args.only
        if name in TIMM_MODEL_NAMES:
            import timm_models

            timm_models.timm_main()
        elif name in HF_MODELS_FILE_NAME:
            import huggingface

            huggingface.huggingface_main()
        elif name in TORCHBENCH_MODELS_FILE_NAME:
            import torchbench

            torchbench.torchbench_main()
        elif name.startswith("path:"):
            # Custom model given as --only=path:<FILE>,class:<CLASS>, the form
            # documented in --help; previously rejected as an illegal name.
            # NOTE(review): routed to the torchbench runner on the assumption
            # that it supports loading a model by path — confirm.
            import torchbench

            torchbench.torchbench_main()
        else:
            print(f"Illegal model name? {name}")
            sys.exit(-1)
    else:
        # No --only: run every suite in turn.
        import torchbench

        torchbench.torchbench_main()

        import huggingface

        huggingface.huggingface_main()

        import timm_models

        timm_models.timm_main()