# mypy: allow-untyped-defs
from __future__ import annotations

import csv
import dataclasses
import inspect
import os
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List, Set, Tuple, TYPE_CHECKING

from torch._inductor import config
from torch._inductor.utils import get_benchmark_name


# Prevent circular import
if TYPE_CHECKING:
    from torch._inductor.scheduler import BaseSchedulerNode

# counter for tracking how many kernels have been generated
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
num_bytes_accessed = 0
nodes_num_elem: List[
    Tuple[
        BaseSchedulerNode,
        int,
    ]
] = []
node_runtimes: List[Tuple[BaseSchedulerNode, float]] = []

# counters for tracking fusions
ir_nodes_pre_fusion = 0

# counters for tracking to_dtype inserted
cpp_to_dtype_count = 0


@dataclasses.dataclass
class CppOuterLoopFusedCount:
    inner_kernel_number: int
    local_buffer_number: int = 0


# The length counts the number of outer loop fusions.
cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = []

num_comprehensive_padding = 0
num_matches_for_scatter_upon_const_tensor = 0

num_loop_reordering = 0

# reset all counters
def reset():
    global generated_kernel_count
    global generated_cpp_vec_kernel_count
    global num_bytes_accessed, nodes_num_elem
    global ir_nodes_pre_fusion
    global cpp_to_dtype_count
    global cpp_outer_loop_fused_inner_counts
    global num_comprehensive_padding
    global num_matches_for_scatter_upon_const_tensor
    global num_loop_reordering

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
    num_bytes_accessed = 0
    nodes_num_elem.clear()
    node_runtimes.clear()
    ir_nodes_pre_fusion = 0
    cpp_to_dtype_count = 0
    cpp_outer_loop_fused_inner_counts.clear()
    num_comprehensive_padding = 0
    num_matches_for_scatter_upon_const_tensor = 0
    num_loop_reordering = 0

@dataclass
class CachedMetricsDeltas:
    """
    The subset of metrics we want to update across cache hits, e.g., from
    the FxGraphCache.
    """

    generated_kernel_count: int
    generated_cpp_vec_kernel_count: int
    ir_nodes_pre_fusion: int
    cpp_to_dtype_count: int
    num_bytes_accessed: int
    num_matches_for_scatter_upon_const_tensor: int


def get_metric_fields():
    return [field.name for field in dataclasses.fields(CachedMetricsDeltas)]


class CachedMetricsHelper:
    """
    A helper class to calculate and apply counter deltas for the metrics we
    want to save with cache entries (e.g., FxGraphCache) and apply on a
    cache hit. See the usage sketch below this class.
    """

    def __init__(self) -> None:
        self.cached_metrics = {}
        for metric in get_metric_fields():
            self.cached_metrics[metric] = globals()[metric]

    def get_deltas(self) -> CachedMetricsDeltas:
        delta_metrics = {}
        for metric in get_metric_fields():
            delta_metrics[metric] = globals()[metric] - self.cached_metrics[metric]

        return CachedMetricsDeltas(**delta_metrics)

    @staticmethod
    def apply_deltas(delta: CachedMetricsDeltas):
        for metric in get_metric_fields():
            globals()[metric] += getattr(delta, metric)

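# Usage sketch (illustrative only, not called from this module): the intent
# is to snapshot the counters before a compilation, save the deltas alongside
# the cache entry, and re-apply them when that entry is hit later.
#
#   helper = CachedMetricsHelper()        # snapshot current counter values
#   ...                                   # run a compilation that bumps them
#   deltas = helper.get_deltas()          # what this compilation contributed
#   # store `deltas` with the cache entry; on a later cache hit:
#   CachedMetricsHelper.apply_deltas(deltas)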

REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {}


@dataclass
class MetricTable:
    table_name: str
    column_names: List[str]

    num_rows_added: int = 0

    def add_row(self, row_fn):
        if self.table_name not in enabled_metric_tables():
            return

        row_dict = row_fn()
        assert len(self.column_names) == len(
            row_dict
        ), f"{len(self.column_names)} vs. {len(row_dict)}"
        assert set(self.column_names) == set(
            row_dict.keys()
        ), f"{set(self.column_names)} vs. {set(row_dict.keys())}"

        row = [
            get_benchmark_name(),
        ]
        row += [row_dict[column_name] for column_name in self.column_names]
        self._write_row(row)

    def output_filename(self):
        return f"metric_table_{self.table_name}.csv"

    def write_header(self):
        filename = self.output_filename()
        with open(filename, "w") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            writer.writerow(["model_name"] + self.column_names)

    def _write_row(self, row):
        filename = self.output_filename()
        if self.num_rows_added == 0 and not os.path.exists(filename):
            self.write_header()

        self.num_rows_added += 1

        for idx, orig_val in enumerate(row):
            if isinstance(orig_val, float):
                new_val = f"{orig_val:.6f}"
            elif orig_val is None:
                new_val = ""
            else:
                new_val = orig_val
            row[idx] = new_val

        with open(filename, "a") as fd:
            writer = csv.writer(fd, lineterminator="\n")
            writer.writerow(row)

    @staticmethod
    def register_table(name, column_names):
        table = MetricTable(name, column_names)
        REGISTERED_METRIC_TABLES[name] = table

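# Usage sketch (illustrative only; "my_table" and its columns are made-up
# names rather than tables registered below):
#
#   MetricTable.register_table("my_table", ["col_a", "col_b"])
#   get_metric_table("my_table").add_row(lambda: {"col_a": 1, "col_b": 2.0})
#
# add_row takes a callable so the row is only materialized when the table is
# enabled via config.enabled_metric_tables; rows are appended to
# metric_table_<table_name>.csv with the benchmark/model name as the first
# column.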

MetricTable.register_table(
    "slow_fusion",
    [
        "kernel1_path",
        "kernel1_latency",
        "kernel2_path",
        "kernel2_latency",
        "fused_kernel_path",
        "fused_kernel_latency",
        "slow_down_ratio",
    ],
)

# track the fusion statistics for each graph
MetricTable.register_table(
    "graph_stats",
    [
        "graph_id",
        "num_nodes_before_fusion",
        "num_nodes_after_fusion",
    ],
)

# track the perf difference between persistent reduction and non-persistent
# reductions
MetricTable.register_table(
    "persistent_red_perf",
    [
        "kernel1_name",
        "kernel2_name",
        "kernel1_latency",
        "kernel2_latency",
        "size_hints",
        "reduction_hint",
        "speedup",
    ],
)

# Log the fusion failures due to indexing mismatch
MetricTable.register_table(
    "fusion_failure_due_to_indexing_mismatch",
    [
        "pre_grad_graph_id",
        "post_grad_graph_id",
        "node1_name",
        "node2_name",
        "node1_debug_str",
        "node2_debug_str",
        "common_buffer_names",
        "failure_reason",
    ],
)

# Log metadata for pointwise/reduction kernels, e.g., model name, kernel path,
# numel, rnumel, reduction hint.
MetricTable.register_table(
    "kernel_metadata",
    [
        "kernel_name",
        "kernel_path",
        "kernel_category",  # pointwise/reduction/foreach etc.
        "size_hints",
        "reduction_hint",
        "line_of_code",
        "num_load",
        "num_store",
        "num_for_loop",
        "num_atomic_add",
        "num_args",
        # xyz numel can be different from size_hints since size_hints are
        # rounded up to the nearest power of 2.
        # Inductor burns the xyz numel into the kernel code for static-shape
        # kernels.
        # Logging them is helpful for finding unaligned shapes in reductions.
        "xnumel",
        "ynumel",
        "rnumel",
        "kernel_args_num_gb",
    ],
)


def _parse_kernel_fn_code(kernel_module_code):
    """
    kernel_module_code is the Python module source code that contains the
    kernel function. The kernel function is the proper Triton kernel function
    annotated with @triton.jit.
    """
    from .codecache import PyCodeCache
    from .wrapper_benchmark import get_triton_kernel

    mod = PyCodeCache.load(kernel_module_code)
    kernel = get_triton_kernel(mod)
    # kernel is a CachingAutotuner; kernel.fn is the JITFunction;
    # kernel.fn.fn is the function decorated by triton.jit
    return inspect.getsource(kernel.fn.fn)


def _parse_kernel_line_of_code(proper_kernel_fn_code):
    """
    Return the number of lines of code for the kernel, excluding the decorators.
    """
    return len(proper_kernel_fn_code.splitlines())


def _parse_size_hints(kernel_module_code, kernel_category):
    if kernel_category == "foreach":
        # foreach kernels do not have size_hints
        return None
    m = re.search(r"size_hints=(\[[0-9, ]*\]),", kernel_module_code)
    assert m, "size_hints missing!"
    return m.group(1)

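# For illustration (a made-up snippet, not tied to a specific generated
# kernel): given module code containing `size_hints=[8, 1024],`,
# _parse_size_hints returns the string "[8, 1024]".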

def _parse_reduction_hint(kernel_category, kernel_module_code):
    if kernel_category not in ("reduction", "persistent_reduction"):
        return None
    m = re.search(r"reduction_hint=ReductionHint\.(\w*),", kernel_module_code)
    assert m, "reduction_hint not found in kernel source code!"
    return m.group(1)


def _count_pattern(proper_kernel_fn_code, pattern):
    return proper_kernel_fn_code.count(pattern)


def _count_args(proper_kernel_fn_code):
    def_line = proper_kernel_fn_code.splitlines()[0]
    assert def_line.startswith("def ")
    start_idx = def_line.index("(")
    end_idx = def_line.index("):")
    decl_csv = def_line[start_idx + 1 : end_idx]
    comps = decl_csv.split(",")
    return len(comps)

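# For illustration (a hypothetical signature, not from a real generated
# kernel): on a first line like
#     def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
# _count_args splits the text between "(" and "):" on commas and returns 4.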

def _parse_proper_kernel_fn_code(kernel_fn_code):
    """
    Skip decorators.
    """
    start_pos = kernel_fn_code.index("def ")
    return kernel_fn_code[start_pos:]


def _parse_numel(proper_kernel_fn_code, numel_arg_name):
    m = re.search(f"{numel_arg_name} = ([\\d]+)", proper_kernel_fn_code)
    if m:
        return int(m.group(1))
    else:
        return None


def _parse_kernel_args_num_gb(kernel_fn_code, kernel_category):
    """
    inductor_meta looks like:
        inductor_meta={... 'mutated_arg_names': [], 'no_x_dim': False, 'kernel_num_gb': 2.0},
    """
    m = re.search(r".kernel_num_gb.:\s*([0-9.]+)", kernel_fn_code)
    if m:
        return float(m.group(1))
    else:
        """
        There are a few cases where the kernel_num_gb field can be missing:
        1. the field is missing if config.benchmark_kernel and
           config.profile_bandwidth are both false
        2. even if config.benchmark_kernel or config.profile_bandwidth is true,
           foreach kernels do not have the kernel_num_gb field in their metadata
        """
        return None


def log_kernel_metadata(kernel_name, kernel_path, kernel_module_code):
    """
    A utility to log kernel metadata. We may parse metadata from the kernel
    source code here.

    It's fine to parse the generated kernel code here since the logging is
    disabled by default; otherwise it would hurt compilation time.
    """
    from .wrapper_benchmark import get_kernel_category_by_source_code

    kernel_category = get_kernel_category_by_source_code(kernel_module_code)
    reduction_hint = _parse_reduction_hint(kernel_category, kernel_module_code)
    size_hints = _parse_size_hints(kernel_module_code, kernel_category)
    kernel_fn_code = _parse_kernel_fn_code(kernel_module_code)

    proper_kernel_fn_code = _parse_proper_kernel_fn_code(kernel_fn_code)

    # the number of lines of code, excluding the decorators
    kernel_line_of_code = _parse_kernel_line_of_code(proper_kernel_fn_code)

    get_metric_table("kernel_metadata").add_row(
        lambda: {
            "kernel_name": kernel_name,
            "kernel_path": kernel_path,
            "kernel_category": kernel_category,
            "size_hints": size_hints,
            "reduction_hint": reduction_hint,
            "line_of_code": kernel_line_of_code,
            "num_load": _count_pattern(proper_kernel_fn_code, "tl.load"),
            "num_store": _count_pattern(proper_kernel_fn_code, "tl.store"),
            "num_for_loop": _count_pattern(proper_kernel_fn_code, "for "),
            "num_atomic_add": _count_pattern(proper_kernel_fn_code, "tl.atomic_add"),
            "num_args": _count_args(proper_kernel_fn_code),
            "xnumel": _parse_numel(proper_kernel_fn_code, "xnumel"),
            "ynumel": _parse_numel(proper_kernel_fn_code, "ynumel"),
            "rnumel": _parse_numel(proper_kernel_fn_code, "rnumel"),
            "kernel_args_num_gb": _parse_kernel_args_num_gb(
                kernel_fn_code, kernel_category
            ),
        }
    )


def purge_old_log_files():
    """
    Purge old log files when the benchmark script starts. This should be done
    in the parent process rather than in the child processes that run each
    individual model.
    """
    for name, table in REGISTERED_METRIC_TABLES.items():
        if name in enabled_metric_tables():
            filename = table.output_filename()
            if os.path.exists(filename):
                os.unlink(filename)

            table.write_header()

@lru_cache
def enabled_metric_tables() -> Set[str]:
    config_str = config.enabled_metric_tables

    enabled = set()
    for name in config_str.split(","):
        name = name.strip()
        if not name:
            continue
        assert (
            name in REGISTERED_METRIC_TABLES
        ), f"Metric table name {name} is not registered"
        enabled.add(name)
    return enabled

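# Example (illustrative): tables are enabled by setting
# torch._inductor.config.enabled_metric_tables to a comma-separated list of
# registered table names, e.g. "kernel_metadata,graph_stats". (An environment
# variable such as TORCHINDUCTOR_ENABLED_METRIC_TABLES typically maps to this
# config, but check torch/_inductor/config.py for the exact spelling.)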

def is_metric_table_enabled(name):
    return name in enabled_metric_tables()


def get_metric_table(name):
    assert name in REGISTERED_METRIC_TABLES, f"Metric table {name} is not defined"
    return REGISTERED_METRIC_TABLES[name]
437