xref: /aosp_15_r20/external/pytorch/benchmarks/nested/nested_bmm_bench.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import random
3
4import torch
5
6
def bench(nt_a, nt_b, niter):
    """Return the average time of one ``nt_a.bmm(nt_b)`` call on the GPU.

    One untimed call is issued first as warmup. The timed region is bracketed
    by a pair of CUDA events, and the device is synchronized before the first
    event is recorded and after the last one, so only the ``niter`` timed
    calls contribute to the measurement.

    Args:
        nt_a: Left operand; must provide ``bmm`` and live on a CUDA device.
        nt_b: Right operand passed to ``nt_a.bmm``.
        niter: Number of timed iterations; must be positive.

    Returns:
        Elapsed time reported by the CUDA events divided by ``niter``
        (``torch.cuda.Event.elapsed_time`` reports milliseconds).

    Raises:
        ValueError: If ``niter`` is not positive. Checked before any GPU work
            so a bad argument fails fast instead of dividing by zero or
            returning a meaningless non-positive runtime.
    """
    if niter <= 0:
        raise ValueError(f"niter must be a positive integer, got {niter!r}")

    # Warmup: keep one-time setup costs out of the timed region.
    nt_a.bmm(nt_b)

    # Drain pending work so the start event marks the real beginning.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(niter):
        # The product itself is not needed; only its execution time is.
        nt_a.bmm(nt_b)
    end_event.record()
    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / niter
21
22
def sweep_n(niter, dtype):
    """Benchmark nested ``bmm`` over several batch sizes, printing a CSV row each.

    For every batch size, ``ntensor`` random ``(256, L)`` matrices are created
    with ``L`` drawn by ``random.randint(100, 200)``. The matrices and their
    transposes are packed into two CUDA nested tensors, the product is timed
    with ``bench``, and one comma-separated row is printed:
    ``ntensor,dtype,min_length,mean_length,max_length,runtime``.

    Args:
        niter: Number of timed iterations forwarded to ``bench``.
        dtype: Element dtype used for both nested tensors.
    """
    for ntensor in (4, 8, 16, 32, 64, 128, 256):
        components = []
        for _ in range(ntensor):
            width = random.randint(100, 200)
            components.append(torch.randn(256, width))

        lhs = torch.nested.nested_tensor(components, dtype=dtype, device="cuda")
        rhs = torch.nested.nested_tensor(
            [mat.t() for mat in components], dtype=dtype, device="cuda"
        )
        elapsed = bench(lhs, rhs, niter)

        # Column 1 of the per-component size table holds each matrix's
        # second dimension, i.e. the randomly drawn length.
        widths = torch.ops.aten._nested_tensor_size(lhs)[:, 1]
        row = (
            ntensor,
            dtype,
            widths.min().item(),
            widths.float().mean().item(),
            widths.max().item(),
            elapsed,
        )
        print(",".join(str(field) for field in row))
54
55
if __name__ == "__main__":
    # Fixed seed so the randomly drawn component lengths repeat across runs.
    random.seed(123)

    cli = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
    cli.add_argument("--niter", default="10", type=int)
    niter = cli.parse_args().niter

    # CSV header for the rows printed by sweep_n.
    print("ntensor,dtype,min_length,mean_length,max_length,runtime")
    for dtype in (torch.float32, torch.float16):
        sweep_n(niter, dtype)
67