# Source (xref): /aosp_15_r20/external/pytorch/test/test_throughput_benchmark.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: unknown"]

import os

import torch
from torch.testing._internal.common_utils import run_tests, TemporaryFileName, TestCase
from torch.utils import ThroughputBenchmark


class TwoLayerNet(torch.jit.ScriptModule):
    """Two-input network written as a legacy TorchScript ``ScriptModule``.

    Both inputs pass through the same ``linear1`` layer followed by
    ``clamp(min=0)``; the two activations are concatenated along dim 1 and
    projected by ``linear2``.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # A single first layer shared by both inputs.
        self.linear1 = torch.nn.Linear(D_in, H)
        # The second layer sees both H-wide activations side by side.
        concat_width = 2 * H
        self.linear2 = torch.nn.Linear(concat_width, D_out)

    @torch.jit.script_method
    def forward(self, x1, x2):
        hidden = torch.cat(
            (self.linear1(x1).clamp(min=0), self.linear1(x2).clamp(min=0)),
            1,
        )
        return self.linear2(hidden)


class TwoLayerNetModule(torch.nn.Module):
    """Eager ``nn.Module`` with the same layer layout and forward math as the
    scripted ``TwoLayerNet``."""

    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H * 2, D_out)

    def forward(self, x1, x2):
        # Shared first layer + clamp(min=0), applied to x1 first, then x2.
        activations = [self.linear1(x).clamp(min=0) for x in (x1, x2)]
        # Join the two activations along dim 1 and project.
        return self.linear2(torch.cat(activations, 1))


class TestThroughputBenchmark(TestCase):
    """Tests for ``torch.utils.ThroughputBenchmark`` with scripted and eager modules."""

    def linear_test(self, Module, profiler_output_path=""):
        """Check ThroughputBenchmark against direct module calls, then benchmark.

        Args:
            Module: module class constructed as ``Module(D_in, H, D_out)`` whose
                ``forward`` takes two tensors ``(x1, x2)``.
            profiler_output_path: forwarded unchanged to ``bench.benchmark``.
        """
        D_in = 10
        H = 5
        D_out = 15
        B = 8
        NUM_INPUTS = 2

        module = Module(D_in, H, D_out)

        # Each entry holds the positional arguments (x1, x2) of one forward call.
        inputs = [
            [torch.randn(B, D_in), torch.randn(B, D_in)] for _ in range(NUM_INPUTS)
        ]
        bench = ThroughputBenchmark(module)

        for x1, x2 in inputs:
            # add_input can take both positional and keyword arguments.
            bench.add_input(x1, x2=x2)

        for args in inputs:
            # run_once must produce the same result as calling the module directly.
            module_result = module(*args)
            bench_result = bench.run_once(*args)
            torch.testing.assert_close(bench_result, module_result)

        stats = bench.benchmark(
            num_calling_threads=4,
            num_warmup_iters=100,
            num_iters=1000,
            profiler_output_path=profiler_output_path,
        )

        print(stats)

    def test_script_module(self):
        self.linear_test(TwoLayerNet)

    def test_module(self):
        self.linear_test(TwoLayerNetModule)

    def test_profiling(self):
        with TemporaryFileName() as fname:
            self.linear_test(TwoLayerNetModule, profiler_output_path=fname)
            # Existence alone proves nothing (TemporaryFileName may pre-create
            # the file), so require that a trace was actually written into it.
            self.assertGreater(os.path.getsize(fname), 0)


83*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
84*da0073e9SAndroid Build Coastguard Worker    run_tests()
85