xref: /aosp_15_r20/external/pytorch/test/inductor/test_inductor_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2
3import functools
4import logging
5
6import torch
7from torch._inductor.runtime.benchmarking import benchmarker
8from torch._inductor.test_case import run_tests, TestCase
9from torch._inductor.utils import do_bench_using_profiling
10
11
12log = logging.getLogger(__name__)
13
14
class TestBench(TestCase):
    """Smoke tests for inductor's GPU benchmarking helpers.

    Both tests time the same fp16 ``torch.nn.functional.linear`` call
    (input 1024x10, weight 512x10). They only assert that a positive timing
    comes back, not that the timing is accurate.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared benchmark callable, or skip the class without CUDA.

        Stores the callable on ``cls._bench_fn`` so every test benchmarks the
        same operation on the same tensors.

        Raises:
            unittest.SkipTest: if no CUDA device is available.
        """
        # Local import keeps this block self-contained; the file's import
        # header does not bring in ``unittest``.
        import unittest

        # Skip *before* calling the parent fixture. unittest does not run
        # tearDownClass when setUpClass raises, so skipping after the parent
        # fixture ran would bypass whatever cleanup its tearDownClass does.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is required for TestBench")
        super().setUpClass()
        x = torch.rand(1024, 10).cuda().half()
        w = torch.rand(512, 10).cuda().half()
        # The tests read this back through ``self._bench_fn``.
        # functools.partial is becoming a method descriptor (Python 3.13
        # emits a FutureWarning about it). Once it binds, ``self`` would be
        # passed as ``linear``'s first argument. staticmethod keeps it
        # unbound on every Python version.
        cls._bench_fn = staticmethod(
            functools.partial(torch.nn.functional.linear, x, w)
        )

    def test_benchmarker(self):
        """``benchmarker.benchmark_gpu`` returns a positive timing."""
        res = benchmarker.benchmark_gpu(self._bench_fn)
        log.warning("do_bench result: %s", res)
        self.assertGreater(res, 0)

    def test_do_bench_using_profiling(self):
        """``do_bench_using_profiling`` returns a positive timing."""
        res = do_bench_using_profiling(self._bench_fn)
        log.warning("do_bench_using_profiling result: %s", res)
        self.assertGreater(res, 0)
32
33
if __name__ == "__main__":
    # Script entry point: hand off to the inductor test runner, forwarding
    # "cuda" as its requirement argument. NOTE(review): presumably this makes
    # the run a no-op on machines without CUDA -- confirm in run_tests.
    backend_requirement = "cuda"
    run_tests(backend_requirement)
36