xref: /aosp_15_r20/external/pytorch/benchmarks/overrides_benchmark/bench.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import time
3
4from common import SubTensor, SubWithTorchFunction, WithTorchFunction
5
6import torch
7
8
# Number of back-to-back torch.add calls timed as one measurement.
# main() overwrites this with the --nreps command-line value.
NUM_REPEATS = 1000
# Number of measurements taken for each tensor type.
# main() overwrites this with the --nrepreps command-line value.
NUM_REPEAT_OF_REPEATS = 1000
11
12
def bench(t1, t2):
    """Measure the per-call cost of ``torch.add(t1, t2)``.

    Takes ``NUM_REPEAT_OF_REPEATS`` measurements, each timing
    ``NUM_REPEATS`` consecutive calls. Both counts are module globals that
    are read when this function is called.

    Args:
        t1: First operand handed to ``torch.add``.
        t2: Second operand handed to ``torch.add``.

    Returns:
        A ``(bench_time, bench_std)`` tuple of floats, in seconds per call:
        the fastest measurement and the standard deviation across all
        measurements, each divided by ``NUM_REPEATS``.

    Raises:
        ValueError: If ``NUM_REPEATS`` or ``NUM_REPEAT_OF_REPEATS`` is
            smaller than 1.
    """
    calls_per_measurement = NUM_REPEATS
    measurements = NUM_REPEAT_OF_REPEATS
    if calls_per_measurement < 1 or measurements < 1:
        raise ValueError(
            "NUM_REPEATS and NUM_REPEAT_OF_REPEATS must both be at least 1, "
            f"got {calls_per_measurement} and {measurements}."
        )

    bench_times = []
    for _ in range(measurements):
        # perf_counter is the clock intended for measuring short intervals;
        # unlike time.time() it is not affected by wall-clock adjustments.
        time_start = time.perf_counter()
        for _ in range(calls_per_measurement):
            torch.add(t1, t2)
        bench_times.append(time.perf_counter() - time_start)

    samples = torch.tensor(bench_times)
    # Divide by the real number of calls per measurement so the result
    # stays a per-call figure when NUM_REPEATS is not the default 1000.
    bench_time = float(torch.min(samples)) / calls_per_measurement
    bench_std = float(torch.std(samples)) / calls_per_measurement

    return bench_time, bench_std
25
26
def main():
    """Parse the repeat counts and benchmark ``torch.add`` per tensor type.

    Reads ``--nreps``/``-n`` and ``--nrepreps``/``-m`` from the command
    line and stores them in the module globals ``NUM_REPEATS`` and
    ``NUM_REPEAT_OF_REPEATS``, which ``bench`` reads. Then, for each of
    ``torch.tensor``, ``SubTensor``, ``WithTorchFunction`` and
    ``SubWithTorchFunction``, it builds two one-element operands, times
    them with ``bench`` and prints the minimum time and the standard
    deviation converted from seconds to microseconds.

    Exits through ``argparse`` with a usage error if either count is
    smaller than 1.
    """
    global NUM_REPEATS
    global NUM_REPEAT_OF_REPEATS

    parser = argparse.ArgumentParser(
        description="Run the __torch_function__ benchmarks."
    )
    parser.add_argument(
        "--nreps",
        "-n",
        type=int,
        default=NUM_REPEATS,
        help="The number of repeats for one measurement.",
    )
    parser.add_argument(
        "--nrepreps",
        "-m",
        type=int,
        default=NUM_REPEAT_OF_REPEATS,
        help="The number of measurements.",
    )
    args = parser.parse_args()

    # A zero or negative count leaves bench with no calls to time or no
    # measurements to summarize, so reject it here with a clear message.
    if args.nreps < 1 or args.nrepreps < 1:
        parser.error("--nreps and --nrepreps must both be at least 1.")

    NUM_REPEATS = args.nreps
    NUM_REPEAT_OF_REPEATS = args.nrepreps

    types = torch.tensor, SubTensor, WithTorchFunction, SubWithTorchFunction

    for t in types:
        tensor_1 = t([1.0])
        tensor_2 = t([2.0])

        bench_min, bench_std = bench(tensor_1, tensor_2)
        # bench reports seconds per call; scale by 10**6 to print microseconds.
        print(
            f"Type {t.__name__} had a minimum time of {10**6 * bench_min} us"
            f" and a standard deviation of {(10**6) * bench_std} us."
        )
64
65
# Run the benchmarks only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
68