# xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/conftest.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import pytest  # noqa: F401
2
3
# Network names parametrized into ``TestBenchNetwork`` (via ``all_nets``).
# Recurrent-network variants:
default_rnns = [
    "cudnn",
    "aten",
    "jit",
    "jit_premul",
    "jit_premul_bias",
    "jit_simple",
    "jit_multilayer",
    "py",
]
# Convolutional-network variants (each model listed plain and as "_jit"):
default_cnns = ["resnet18", "resnet18_jit", "resnet50", "resnet50_jit"]
# Every network to benchmark: the RNN names first, then the CNN names.
all_nets = [*default_rnns, *default_cnns]
16
17
def pytest_generate_tests(metafunc):
    """Parametrize the ``TestBenchNetwork`` benchmarks.

    Each test in the ``TestBenchNetwork`` class is parametrized, with class
    scope, over every name in ``all_nets`` and over the single ``--executor``
    and ``--fuser`` values taken from the command line (registered in
    ``pytest_addoption``). Tests outside that class are left untouched.

    Args:
        metafunc: the pytest ``Metafunc`` for the test function being
            collected.
    """
    # ``metafunc.cls`` is None for module-level test functions; without this
    # guard, collecting such a test would raise AttributeError here.
    cls = metafunc.cls
    if cls is None or cls.__name__ != "TestBenchNetwork":
        return

    # This creates lists of tests to generate, can be customized.
    metafunc.parametrize("net_name", all_nets, scope="class")
    # Each option value is wrapped in a one-element list so it becomes a
    # class-scoped parameter alongside ``net_name``.
    metafunc.parametrize(
        "executor", [metafunc.config.getoption("executor")], scope="class"
    )
    metafunc.parametrize(
        "fuser", [metafunc.config.getoption("fuser")], scope="class"
    )
28
29
def pytest_addoption(parser):
    """Register the command-line options used by the benchmark tests.

    Adds ``--fuser`` (default ``"old"``) and ``--executor`` (default
    ``"legacy"``). Their values are read back with ``config.getoption`` in
    ``pytest_generate_tests``.
    """
    option_specs = (
        ("--fuser", "old", "fuser to use for benchmarks"),
        ("--executor", "legacy", "executor to use for benchmarks"),
    )
    for flag, default_value, help_text in option_specs:
        parser.addoption(flag, default=default_value, help=help_text)
35