# xref: /aosp_15_r20/external/pytorch/benchmarks/serialization/simple_measurement.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import os
import tempfile

from pyarkbench import Benchmark, default_args, Timer

import torch
4
5
# Forwarded to torch.save as `_use_new_zipfile_serialization`; True requests the
# newer zipfile-based on-disk format (see the torch.save documentation).
use_new: bool = True
7
8
class Basic(Benchmark):
    """Time ``torch.save`` / ``torch.load`` round trips for two tensor workloads.

    Workloads:
      * "Big":   a list of 30 tensors of shape (200, 200), filled with ones.
      * "Small": a list of 200 tensors of shape (10, 10), filled with ones.

    The on-disk format is selected by the module-level ``use_new`` flag, which
    is forwarded to ``torch.save`` as ``_use_new_zipfile_serialization``.
    """

    def benchmark(self):
        """Run one save/load iteration of both workloads.

        Returns:
            dict mapping each phase label ("Big Tensors Save", "Big Tensors
            Load", "Small Tensors Save", "Small Tensors Load") to the
            ``Timer.ms_duration`` value measured for that phase.
        """
        # Serialize into a fresh scratch directory created under the current
        # working directory. This keeps the I/O on the same filesystem the
        # benchmark has always used (unlike the system temp dir, which may be
        # RAM-backed). It also avoids clobbering any user file named
        # big_tensor.zip / small_tensor.zip, and the files are removed on exit.
        with tempfile.TemporaryDirectory(dir=".") as scratch:
            # The small workload is only built after the big round trip has
            # finished, matching the original allocation order.
            big_save, big_load = self._time_round_trip(
                [torch.ones(200, 200) for _ in range(30)],
                os.path.join(scratch, "big_tensor.zip"),
            )
            small_save, small_load = self._time_round_trip(
                [torch.ones(10, 10) for _ in range(200)],
                os.path.join(scratch, "small_tensor.zip"),
            )

        return {
            "Big Tensors Save": big_save,
            "Big Tensors Load": big_load,
            "Small Tensors Save": small_save,
            "Small Tensors Load": small_load,
        }

    @staticmethod
    def _time_round_trip(tensors, path):
        """Save ``tensors`` to ``path``, load them back, return both timings.

        Returns:
            (save_duration, load_duration) as reported by ``Timer.ms_duration``.
        """
        with Timer() as save_timer:
            torch.save(tensors, path, _use_new_zipfile_serialization=use_new)
        with Timer() as load_timer:
            # Only the load time matters; the loaded object is discarded.
            torch.load(path)
        return save_timer.ms_duration, load_timer.ms_duration
31
32
if __name__ == "__main__":
    # Construct the benchmark from pyarkbench's default_args.bench() settings,
    # announce which serialization format is being measured, then run it and
    # print the summary statistics.
    runner = Basic(*default_args.bench())
    print("Use zipfile serialization:", use_new)
    timings = runner.run()
    summary_stats = ["mean", "median"]
    runner.print_stats(timings, stats=summary_stats)
38