# mypy: allow-untyped-defs
import dataclasses
import tempfile
from collections import defaultdict

import torch
from torch.autograd import DeviceType

from .runtime.benchmarking import benchmarker
from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes


_kernel_category_choices = [
    "foreach",
    "persistent_reduction",
    "pointwise",
    "reduction",
    "split_scan",
    "template",
]


def get_kernel_category_by_source_code(src_code):
    """
    Similar to get_kernel_category, but works on the source code. Call this API
    if we have not yet compiled src_code into a module.
    """
    choices = [
        ch for ch in _kernel_category_choices if f"@triton_heuristics.{ch}" in src_code
    ]
    if len(choices) == 1:
        return choices[0]
    else:
        return "unknown"


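# A minimal sketch (not part of the original file) of how the matcher above
# behaves. The source snippet is hypothetical; generated Triton source
# decorates each kernel with @triton_heuristics.<category>, which is exactly
# the substring we look for.
def _example_get_kernel_category_by_source_code():
    src = "@triton_heuristics.pointwise(size_hints=[1024])\ndef triton_poi_fused_0(...): ..."
    assert get_kernel_category_by_source_code(src) == "pointwise"
    # With no (or more than one) matching decorator we fall back to "unknown".
    assert get_kernel_category_by_source_code("") == "unknown"

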
def get_kernel_category(kernel_mod):
    """
    Given the module defining a triton kernel, return the category of the kernel.
    The category is one of the _kernel_category_choices above, e.g.:
    - pointwise
    - reduction
    - persistent_reduction

    Currently we decide the category simply by checking which decorator is
    imported by the kernel module.
    """
    choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__]
    if len(choices) == 1:
        return choices[0]
    else:
        return "unknown"


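# A hypothetical sketch of what get_kernel_category inspects: a compiled kernel
# module imports exactly one heuristics decorator, so that name shows up in the
# module's namespace. types.SimpleNamespace stands in for a real module here.
def _example_get_kernel_category():
    import types

    fake_mod = types.SimpleNamespace(reduction=object())
    assert get_kernel_category(fake_mod) == "reduction"

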
def get_triton_kernel(mod):
    from torch._inductor.runtime.triton_heuristics import CachingAutotuner

    cand_list = [
        v
        for k, v in mod.__dict__.items()
        if k.startswith("triton_") and isinstance(v, CachingAutotuner)
    ]
    assert len(cand_list) == 1, "Expected exactly one triton kernel in the module"
    return cand_list[0]


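# The naming convention get_triton_kernel relies on (a sketch; the exact
# codegen'ed lines may differ): a compiled kernel module binds its autotuner
# to a "triton_"-prefixed global, e.g.
#
#   triton_poi_fused_add_0 = async_compile.triton("triton_poi_fused_add_0", ...)
#
# so scanning mod.__dict__ for CachingAutotuner instances finds exactly one.

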
def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
    """
    An experimental API used only when config.benchmark_kernel is true.

    Run the kernel benchmarks for all the kernels cached in PyCodeCache.
    Used in the compiled modules.

    This method lives here rather than being codegen'ed, for convenience:
    its implementation does not change based on the different graph modules
    being compiled.
    """
    from torch._inductor.codecache import PyCodeCache

    nfound = 0
    for kernel_key, kernel_mod in PyCodeCache.cache.items():
        if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
            continue

        triton_kernel = get_triton_kernel(kernel_mod)
        kernel_category = get_kernel_category(kernel_mod)
        args = kernel_mod.get_args()
        num_in_out_ptrs = len(
            [
                arg_name
                for arg_name in triton_kernel.fn.arg_names
                if arg_name.startswith("in_out_ptr")
            ]
        )
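        # in_out_ptr arguments are both read and written, so get_num_bytes
        # counts each of them twice (once for the read, once for the write)
        # when estimating the total bytes moved.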
        num_gb = triton_kernel.inductor_meta.get("kernel_num_gb", None)
        if num_gb is None:
            num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9

        def get_info_str(ms, n_regs, n_spills, shared, prefix=""):
            if not any(x is None for x in [n_regs, n_spills, shared]):
                kernel_detail_str = (
                    f"  {n_regs:3} regs  {n_spills:3} spills  {shared:8} shared mem"
                )
            else:
                kernel_detail_str = ""

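            # Achieved memory bandwidth in GB/s: GB moved divided by the
            # kernel runtime converted from milliseconds to seconds.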
            gb_per_s = num_gb / (ms / 1e3)
            return create_bandwidth_info_str(
                ms, num_gb, gb_per_s, prefix=prefix, suffix=kernel_detail_str
            )

        kernel_desc = (
            f"{benchmark_name:20} {kernel_category[:3].upper()} {kernel_key[:10]}"
        )
        if benchmark_all_configs:
            assert hasattr(kernel_mod, "benchmark_all_configs")
            bench_result = kernel_mod.benchmark_all_configs(args)
            print(kernel_desc)
            for launcher, ms in bench_result.items():
                print(
                    f"  {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
                )
        else:
            ms = benchmarker.benchmark_gpu(
                lambda: kernel_mod.call(args), rep=40, fast_flush=True
            )
            assert (
                len(triton_kernel.launchers) == 1
            ), "Autotuner should have selected the best config"
            launcher = triton_kernel.launchers[0]
            print(
                get_info_str(
                    ms,
                    launcher.n_regs,
                    launcher.n_spills,
                    launcher.shared,
                    prefix=f"{kernel_desc} ",
                )
            )

        nfound += 1
    if nfound == 0:
        print(
            "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel set to True."
        )


@dataclasses.dataclass
class ProfileEvent:
    category: str
    key: str
    self_device_time_ms: float
    # The benchmark is run multiple times and we average the count across all
    # the runs. It should be an integer, but we define it as a float just in case.
    count: float


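# A minimal construction sketch (the values are made-up placeholders):
#
#   ev = ProfileEvent(
#       category="triton_pointwise",
#       key="triton_poi_fused_add_0",
#       self_device_time_ms=0.125,
#       count=1.0,
#   )

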
def parse_profile_event_list(
    benchmark_name, event_list, wall_time_ms, nruns, device_name
):
    def get_self_device_time(ev):
        """
        ev.self_device_time_total is in microseconds. Convert it to
        milliseconds and average it across the nruns runs.
        """
        return ev.self_device_time_total / 1000 / nruns

    all_events = defaultdict(list)

    def add_event(ev, category):
        profile_ev = ProfileEvent(
            category=category,
            key=ev.key,
            self_device_time_ms=get_self_device_time(ev),
            count=ev.count / nruns,  # average across all runs
        )
        all_events[category].append(profile_ev)

    for ev in event_list:
        assert not ev.is_legacy, "The legacy profiler is not supported"
        if ev.device_type == DeviceType.CPU:
            # ignore events on the CPU side
            continue

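        # Inductor names generated Triton kernels with a category prefix:
        # triton_poi_* (pointwise), triton_red_* (reduction), and
        # triton_per_* (persistent reduction); anything else that still
        # starts with triton_ is bucketed as triton_unknown.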
        category = "unknown"
        if ev.key.startswith("triton_"):
            if ev.key.startswith("triton_poi"):
                category = "triton_pointwise"
            elif ev.key.startswith("triton_red"):
                category = "triton_reduction"
            elif ev.key.startswith("triton_per"):
                category = "triton_persistent_reduction"
            else:
                category = "triton_unknown"

        add_event(ev, category)

    def report_category(category, profile_events):
        from tabulate import tabulate

        profile_events.sort(key=lambda ev: ev.self_device_time_ms, reverse=True)

        rows = []
        total_time = 0.0
        print(f"\n  == {category} category kernels == ")
        for ev in profile_events:
            total_time += ev.self_device_time_ms
            percent = f"{ev.self_device_time_ms / wall_time_ms * 100:.2f}%"
            rows.append([ev.key[:120], ev.self_device_time_ms, ev.count, percent])
        rows.append(
            ["Total", total_time, "", f"{total_time / wall_time_ms * 100:.2f}%"]
        )
        print(
            tabulate(
                rows,
                headers=[
                    "Kernel",
                    f"Self {device_name.upper()} TIME (ms)",
                    "Count",
                    "Percent",
                ],
            )
        )
        return total_time

    def report():
        category_list = [
            "triton_pointwise",
            "triton_reduction",
            "triton_persistent_reduction",
            "triton_unknown",
            "unknown",
        ]
        assert set(all_events.keys()).issubset(
            set(category_list)
        ), f"{list(all_events.keys())}"

        per_category_wall_time = {}
        total_device_ms = 0.0
        for category in category_list:
            if category in all_events:
                _time = report_category(category, all_events[category])
                per_category_wall_time[category] = _time
                total_device_ms += _time

        device_busy_percent = f"{total_device_ms / wall_time_ms * 100:.2f}%"
        print(
            f"\nPercent of time when {device_name.upper()} is busy: {device_busy_percent}"
        )
        print(f"Total wall time {wall_time_ms:.3f} ms")

        # Output a single line so we can gather these lines from all compiled
        # modules across all benchmarks and tabulate them!
        # Columns: benchmark_name, pointwise_percent, reduction_percent,
        #   persistent_reduction_percent, triton_unknown_percent,
        #   unknown_category_percent, device_busy_percent, wall_time_ms
        tabulate_line = f"Output for tabulate: {benchmark_name}"
        for category in category_list:
            percent = (
                f"{per_category_wall_time.get(category, 0.0) / wall_time_ms * 100:.2f}%"
            )
            tabulate_line += f", {percent}"
        tabulate_line += f", {device_busy_percent}, {wall_time_ms:.3f}ms"

        print(tabulate_line)

    report()


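# Shape of the summary line printed by report() above (the values are
# illustrative placeholders, one percentage per entry in category_list,
# followed by the device-busy percent and the wall time):
#
#   Output for tabulate: <benchmark_name>, <poi%>, <red%>, <per%>,
#   <triton_unknown%>, <unknown%>, <busy%>, <wall_time>ms

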
def compiled_module_main(benchmark_name, benchmark_compiled_module_fn):
    """
    This is the function called in the __main__ block of a compiled module.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--benchmark-kernels",
        "-k",
        action="store_true",
        help="Whether to benchmark each individual kernel",
    )
    parser.add_argument(
        "--benchmark-all-configs",
        "-c",
        action="store_true",
        help="Whether to benchmark each individual config for a kernel",
    )
    parser.add_argument(
        "--profile",
        "-p",
        action="store_true",
        help="Whether to profile the compiled module",
    )
    args = parser.parse_args()

    if args.benchmark_kernels:
        benchmark_all_kernels(benchmark_name, args.benchmark_all_configs)
    else:
        times = 10
        repeat = 10
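        # benchmark_compiled_module_fn is assumed to return average seconds per
        # run; the profiling pass below averages device time over the same
        # times * repeat runs so that the busy percentages are comparable.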
        wall_time_ms = benchmark_compiled_module_fn(times=times, repeat=repeat) * 1000

        if not args.profile:
            return

        with torch.profiler.profile(record_shapes=True) as p:
            benchmark_compiled_module_fn(times=times, repeat=repeat)

        path = f"{tempfile.gettempdir()}/compiled_module_profile.json"
        p.export_chrome_trace(path)
        print(f"Profiling result for a compiled module of benchmark {benchmark_name}:")
        print(f"Chrome trace for the profile is written to {path}")
        event_list = p.key_averages(group_by_input_shape=True)
        print(event_list.table(sort_by="self_device_time_total", row_limit=10))
        parse_profile_event_list(
            benchmark_name, event_list, wall_time_ms, times * repeat, p.use_device
        )
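
# Sketch of how a compiled module typically reaches this entry point (based on
# the docstring above; the exact codegen'ed lines may differ):
#
#   if __name__ == "__main__":
#       from torch._inductor.wrapper_benchmark import compiled_module_main
#       compiled_module_main("my_benchmark", benchmark_compiled_module)
#
# Then `python compiled_module.py -k` benchmarks each cached kernel, and
# `python compiled_module.py -p` profiles the module and prints the
# per-category summary tables from parse_profile_event_list.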