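"""Benchmark RecordFunction overhead for TorchScript ResNet-50 and LSTM models.

For each selected model, the forward pass is timed at several thread counts,
both with and without RecordFunction callbacks enabled, and the results are
summarized with torch.utils.benchmark.Compare.
"""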
import argparse
import sys

from benchmarks.fastrnns.factory import lstm_creator

from torchvision.models import resnet50

import torch
import torch.utils.benchmark as benchmark_utils


def prepare_lstm_jit(bench_args):
    """Build a scripted LSTM via the fastrnns factory and return (inputs, forward_fn)."""
    model_def = lstm_creator(
        script=True,
        seqLength=bench_args.lstmSeqLength,
        numLayers=bench_args.lstmNumLayers,
        inputSize=bench_args.lstmInputSize,
        hiddenSize=bench_args.lstmHiddenSize,
        miniBatch=bench_args.lstmMiniBatch,
        device="cpu",
    )
    return model_def.inputs, model_def.forward


def prepare_resnet50_jit(bench_args):
    """Trace torchvision's ResNet-50 on a random batch of 32 images and return (inputs, model)."""
    model = resnet50()
    inputs = (torch.randn(32, 3, 224, 224),)
    model = torch.jit.trace(model, inputs)
    return inputs, model


MODELS = {
    "resnet50_jit": prepare_resnet50_jit,
    "lstm_jit": prepare_lstm_jit,
}

# Thread counts swept for each model.
NUM_THREADS = [1, 2, 4, 8, 16, 32]


def run_bench(model_names, bench_args):
    """Time each model's forward pass with and without RecordFunction callbacks."""
    results = []
    for model_name in model_names:
        model_creator = MODELS[model_name]
        inputs, model = model_creator(bench_args)

        print("Benchmarking RecordFunction overhead for", model_name)
        print("Running warmup...", end=" ")
        sys.stdout.flush()
        for _ in range(bench_args.warmup):
            model(*inputs)
        print("finished")

        for num_threads in NUM_THREADS:
            for with_rec_fn in [True, False]:
                # Toggle RecordFunction; when enabled, register an empty test
                # observer so that callback dispatch overhead is exercised.
                torch.autograd._enable_record_function(with_rec_fn)
                torch.autograd._clear_callbacks()
                if with_rec_fn:
                    torch.autograd._set_empty_test_observer(True, 0.0001)

                print(
                    "Running {} RecordFunction, num threads {} ...".format(
                        "with" if with_rec_fn else "without", num_threads
                    ),
                    end=" ",
                )
                sys.stdout.flush()
                timer = benchmark_utils.Timer(
                    stmt="model(*inputs)",
                    globals={"model": model, "inputs": inputs},
                    description=model_name,
                    label="Record function overhead",
                    sub_label=f"with{'' if with_rec_fn else 'out'}_rec_fn, num_threads {num_threads}",
                    num_threads=num_threads,
                )
                result = timer.blocked_autorange(
                    min_run_time=bench_args.timer_min_run_time
                )
                print("finished")
                print(result)
                sys.stdout.flush()
                results.append(result)

    comparison = benchmark_utils.Compare(results)
    comparison.trim_significant_figures()
    comparison.highlight_warnings()
    comparison.print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark RecordFunction overhead for ResNet and LSTM models"
    )

    parser.add_argument(
        "--models",
        nargs="*",
        default=["lstm_jit"],
        help="Models to run: " + ", ".join(MODELS.keys()),
    )

    parser.add_argument("--lstmSeqLength", default=100, type=int)
    parser.add_argument("--lstmNumLayers", default=1, type=int)
    parser.add_argument("--lstmInputSize", default=512, type=int)
    parser.add_argument("--lstmHiddenSize", default=512, type=int)
    parser.add_argument("--lstmMiniBatch", default=64, type=int)
    parser.add_argument("--warmup", default=2, type=int)
    parser.add_argument("--nloops", default=50, type=int)  # currently unused
    parser.add_argument(
        "--timer-min-run-time", "--timer_min_run_time", default=120, type=int
    )

    args = parser.parse_args()

    models = args.models or MODELS.keys()

    for model in models:
        assert model in MODELS, f"Unknown model: {model}"
    run_bench(models, args)
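
# Example invocation (illustrative; assumes a PyTorch source checkout is the
# working directory so that the `benchmarks.fastrnns` package is importable):
#
#   python benchmarks/record_function_benchmark/record_function_bench.py \
#       --models lstm_jit resnet50_jit --timer-min-run-time 10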