xref: /aosp_15_r20/external/pytorch/test/inductor/test_benchmarking.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2
3import unittest
4
5import torch
6from torch._dynamo.utils import counters
7from torch._inductor.runtime.benchmarking import Benchmarker, TritonBenchmarker
8from torch._inductor.test_case import run_tests, TestCase
9from torch.testing._internal.common_utils import (
10    decorateIf,
11    instantiate_parametrized_tests,
12    parametrize,
13)
14from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
15
16
# Every benchmarker class that the parametrized tests in this file run against.
ALL_BENCHMARKER_CLASSES = (Benchmarker, TritonBenchmarker)
21
22
@instantiate_parametrized_tests
class TestBenchmarker(TestCase):
    """Smoke tests for the Inductor benchmarker classes.

    Tests are parametrized over ``ALL_BENCHMARKER_CLASSES`` and, where
    relevant, over CPU and GPU devices. Besides checking that a positive
    timing is returned, they check the ``counters["inductor"]`` entries that
    each benchmarking entry point is expected to increment.
    """

    def setUp(self):
        super().setUp()
        # Fixed seed so the random input tensors are reproducible.
        torch.manual_seed(12345)
        # `counters` is module-global state; reset it so every test starts
        # counting from scratch.
        counters.clear()

    @staticmethod
    def get_counter_value(benchmarker_cls, fn_name):
        """Return the Inductor counter recorded for ``benchmarker_cls.fn_name``.

        The counter key has the form ``"benchmarking.<ClassName>.<fn_name>"``.
        """
        return counters["inductor"][
            f"benchmarking.{benchmarker_cls.__name__}.{fn_name}"
        ]

    @staticmethod
    def make_params(device, size=100):
        """Build a ``torch.sum`` workload on ``device``.

        Args:
            device: device on which the random input tensor is allocated.
            size: number of elements in the 1-D random input tensor.

        Returns:
            ``((fn, fn_args, fn_kwargs), _callable)`` where ``fn`` is
            ``torch.sum``, ``fn_args`` is a 1-tuple holding the input tensor,
            ``fn_kwargs`` is an empty dict, and ``_callable`` is a
            zero-argument function that calls ``fn(*fn_args, **fn_kwargs)``.
        """
        fn, fn_args, fn_kwargs = torch.sum, (torch.randn(size, device=device),), {}

        def _callable():
            return fn(*fn_args, **fn_kwargs)

        return (fn, fn_args, fn_kwargs), _callable

    @unittest.skipIf(not HAS_CPU or not HAS_GPU, "requires CPU and GPU")
    @decorateIf(
        unittest.expectedFailure,
        # NOTE(review): presumably the base Benchmarker has no GPU
        # implementation; confirm in torch/_inductor/runtime/benchmarking.py.
        lambda params: params["benchmarker_cls"] is Benchmarker
        and params["device"] == GPU_TYPE,
    )
    @parametrize("benchmarker_cls", ALL_BENCHMARKER_CLASSES)
    @parametrize("device", (GPU_TYPE, "cpu"))
    def test_benchmark_smoke(self, benchmarker_cls, device):
        """``benchmark`` returns a positive timing and routes to the
        device-specific entry point (``benchmark_cpu`` / ``benchmark_gpu``)
        exactly once."""
        benchmarker = benchmarker_cls()
        (fn, fn_args, fn_kwargs), _ = self.make_params(device)
        timing = benchmarker.benchmark(fn, fn_args, fn_kwargs)
        self.assertGreater(timing, 0)
        self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark"), 1)
        self.assertEqual(
            self.get_counter_value(
                benchmarker_cls, "benchmark_cpu" if device == "cpu" else "benchmark_gpu"
            ),
            1,
        )

    @unittest.skipIf(not HAS_CPU, "requires CPU")
    @parametrize("benchmarker_cls", ALL_BENCHMARKER_CLASSES)
    def test_benchmark_cpu_smoke(self, benchmarker_cls, device="cpu"):
        """``benchmark_cpu`` on a zero-arg callable returns a positive timing
        and increments its counter once."""
        benchmarker = benchmarker_cls()
        _, _callable = self.make_params(device)
        timing = benchmarker.benchmark_cpu(_callable)
        self.assertGreater(timing, 0)
        self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark_cpu"), 1)

    @unittest.skipIf(not HAS_GPU, "requires GPU")
    @decorateIf(
        unittest.expectedFailure,
        lambda params: params["benchmarker_cls"] is Benchmarker,
    )
    @parametrize("benchmarker_cls", ALL_BENCHMARKER_CLASSES)
    def test_benchmark_gpu_smoke(self, benchmarker_cls, device=GPU_TYPE):
        """``benchmark_gpu`` on a zero-arg callable returns a positive timing
        and increments its counter once. For ``TritonBenchmarker`` it also
        increments ``triton_do_bench`` once."""
        benchmarker = benchmarker_cls()
        _, _callable = self.make_params(device)
        timing = benchmarker.benchmark_gpu(_callable)
        self.assertGreater(timing, 0)
        self.assertEqual(self.get_counter_value(benchmarker_cls, "benchmark_gpu"), 1)
        if benchmarker_cls is TritonBenchmarker:
            self.assertEqual(
                self.get_counter_value(benchmarker_cls, "triton_do_bench"), 1
            )

    @unittest.skipIf(not HAS_CPU and not HAS_GPU, "requires CPU or GPU")
    @unittest.expectedFailure
    @parametrize("benchmarker_cls", ALL_BENCHMARKER_CLASSES)
    def test_benchmark_safely_infers_device_no_devices(
        self, benchmarker_cls, device="cpu" if HAS_CPU else GPU_TYPE
    ):
        """``benchmark`` is expected to fail when no tensor argument is
        provided, since there is no device to infer from."""
        benchmarker = benchmarker_cls()
        (fn, _, _), _ = self.make_params(device)
        benchmarker.benchmark(fn, (), {})

    @unittest.skipIf(not HAS_CPU or not HAS_GPU, "requires CPU and GPU")
    @unittest.expectedFailure
    @parametrize("benchmarker_cls", ALL_BENCHMARKER_CLASSES)
    def test_benchmark_safely_infers_device_many_devices(self, benchmarker_cls):
        """``benchmark`` is expected to fail when the arguments live on more
        than one device (here: one CPU tensor plus one GPU tensor)."""
        benchmarker = benchmarker_cls()
        # Use the real helper (`make_params`). Calling an undefined helper would
        # raise AttributeError, and `expectedFailure` would then report success
        # without ever reaching the multi-device path under test.
        (fn, cpu_args, cpu_kwargs), _ = self.make_params("cpu")
        (_, gpu_args, gpu_kwargs), _ = self.make_params(GPU_TYPE)
        many_devices_args = cpu_args + gpu_args
        # Build a fresh merged dict rather than mutating `cpu_kwargs` in place.
        many_devices_kwargs = {**cpu_kwargs, **gpu_kwargs}
        benchmarker.benchmark(fn, many_devices_args, many_devices_kwargs)
110
111
if __name__ == "__main__":
    # Executing this file directly hands control to Inductor's test runner.
    run_tests()
114