"""Pytest configuration hooks for the network benchmarks.

Defines the network names to benchmark, registers the ``--fuser`` and
``--executor`` command-line options, and parametrizes the
``TestBenchNetwork`` test class with them.
"""

import pytest  # noqa: F401


# Recurrent network variants to benchmark.
default_rnns = [
    "cudnn",
    "aten",
    "jit",
    "jit_premul",
    "jit_premul_bias",
    "jit_simple",
    "jit_multilayer",
    "py",
]
# Convolutional network variants to benchmark.
default_cnns = ["resnet18", "resnet18_jit", "resnet50", "resnet50_jit"]
# Every network name handed to ``net_name``: RNNs first, then CNNs.
all_nets = default_rnns + default_cnns

# Name of the only test class that receives the benchmark parameters.
_BENCH_CLASS_NAME = "TestBenchNetwork"


def pytest_generate_tests(metafunc):
    """Parametrize ``TestBenchNetwork`` tests; leave every other test alone.

    For tests defined on a class named ``TestBenchNetwork`` this adds three
    class-scoped parameters:

    * ``net_name`` -- one value per entry in ``all_nets``;
    * ``executor`` -- the single value of the ``--executor`` option;
    * ``fuser`` -- the single value of the ``--fuser`` option.

    Args:
        metafunc: the pytest ``Metafunc`` object for the test function
            currently being collected.
    """
    # This creates lists of tests to generate, can be customized
    test_cls = metafunc.cls
    # ``metafunc.cls`` is None for test functions defined at module level;
    # dereferencing ``__name__`` on it would abort collection.
    if test_cls is None or test_cls.__name__ != _BENCH_CLASS_NAME:
        return

    metafunc.parametrize("net_name", all_nets, scope="class")
    metafunc.parametrize(
        "executor", [metafunc.config.getoption("executor")], scope="class"
    )
    metafunc.parametrize(
        "fuser", [metafunc.config.getoption("fuser")], scope="class"
    )


def pytest_addoption(parser):
    """Register the ``--fuser`` and ``--executor`` command-line options.

    Args:
        parser: the pytest option parser to register the options on.
    """
    parser.addoption("--fuser", default="old", help="fuser to use for benchmarks")
    parser.addoption(
        "--executor", default="legacy", help="executor to use for benchmarks"
    )