xref: /aosp_15_r20/external/pytorch/torch/utils/benchmark/utils/compare.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Display class to aggregate and print the results of many measurements."""
3import collections
4import enum
5import itertools as it
6from typing import DefaultDict, List, Optional, Tuple
7
8from torch.utils.benchmark.utils import common
9from torch import tensor as _tensor
10import operator
11
# Only these names are exported by ``from ... import *``; ``Table`` and the
# underscore-prefixed helpers are not listed.
__all__ = ["Colorize", "Compare"]
13
# ANSI SGR escape sequences used by ``_Row.color_segment`` to color cells.
BEST = "\x1b[92m"              # bright green foreground
GOOD = "\x1b[34m"              # blue foreground
BAD = "\x1b[2m\x1b[91m"        # dim + bright red foreground
VERY_BAD = "\x1b[31m"          # red foreground
BOLD = "\x1b[1m"               # bold
TERMINATE = "\x1b[0m"          # reset all attributes
20
21
class Colorize(enum.Enum):
    """Coloring mode for a rendered table (selected via ``Compare.colorize``).

    When enabled, each cell is colored by comparing its median against a
    "best" median chosen according to the mode.
    """
    NONE = "none"  # no ANSI coloring
    COLUMNWISE = "columnwise"  # best = fastest cell in the same column and thread-count row group
    ROWWISE = "rowwise"  # best = fastest cell in the same row
26
27
28# Classes to separate internal bookkeeping from what is rendered.
29class _Column:
30    def __init__(
31        self,
32        grouped_results: List[Tuple[Optional[common.Measurement], ...]],
33        time_scale: float,
34        time_unit: str,
35        trim_significant_figures: bool,
36        highlight_warnings: bool,
37    ):
38        self._grouped_results = grouped_results
39        self._flat_results = list(it.chain(*grouped_results))
40        self._time_scale = time_scale
41        self._time_unit = time_unit
42        self._trim_significant_figures = trim_significant_figures
43        self._highlight_warnings = (
44            highlight_warnings
45            and any(r.has_warnings for r in self._flat_results if r)
46        )
47        leading_digits = [
48            int(_tensor(r.median / self._time_scale).log10().ceil()) if r else None
49            for r in self._flat_results
50        ]
51        unit_digits = max(d for d in leading_digits if d is not None)
52        decimal_digits = min(
53            max(m.significant_figures - digits, 0)
54            for digits, m in zip(leading_digits, self._flat_results)
55            if (m is not None) and (digits is not None)
56        ) if self._trim_significant_figures else 1
57        length = unit_digits + decimal_digits + (1 if decimal_digits else 0)
58        self._template = f"{{:>{length}.{decimal_digits}f}}{{:>{7 if self._highlight_warnings else 0}}}"
59
60    def get_results_for(self, group):
61        return self._grouped_results[group]
62
63    def num_to_str(self, value: Optional[float], estimated_sigfigs: int, spread: Optional[float]):
64        if value is None:
65            return " " * len(self.num_to_str(1, estimated_sigfigs, None))
66
67        if self._trim_significant_figures:
68            value = common.trim_sigfig(value, estimated_sigfigs)
69
70        return self._template.format(
71            value,
72            f" (! {spread * 100:.0f}%)" if self._highlight_warnings and spread is not None else "")
73
74
def optional_min(seq):
    """Return the smallest element of ``seq``, or ``None`` if it is empty.

    ``seq`` may be any iterable, including a one-shot generator.
    """
    return min(seq, default=None)
78
79
class _Row:
    """One rendered row of a ``Table`` (a single row key).

    ``results`` holds one entry per column, ``None`` where this row has no
    measurement. Columns are attached afterwards via ``register_columns``.
    """

    def __init__(self, results, row_group, render_env, env_str_len,
                 row_name_str_len, time_scale, colorize, num_threads=None):
        super().__init__()
        # Cell data and how it is scaled / colored.
        self._results = results
        self._row_group = row_group
        self._time_scale = time_scale
        self._colorize = colorize
        # Row-label layout.
        self._render_env = render_env
        self._env_str_len = env_str_len
        self._row_name_str_len = row_name_str_len
        # When not None, row_separator emits a "<n> threads: ----" line above this row.
        self._num_threads = num_threads
        # Replaced by register_columns once every column has been built.
        self._columns: Tuple[_Column, ...] = ()

    def register_columns(self, columns: Tuple[_Column, ...]):
        """Attach the table's columns, which format this row's cells."""
        self._columns = columns

    def as_column_strings(self):
        """Return the unpadded cell strings: the row label, then one per column."""
        present = [result for result in self._results if result is not None]
        leader = present[0]
        env_label = f"({leader.env})" if self._render_env else ""
        cells = ["  " + env_label.ljust(self._env_str_len + 4) + leader.as_row_name]
        for measurement, column in zip(self._results, self._columns or ()):
            if measurement is None:
                cells.append(column.num_to_str(None, 1, None))
                continue
            scaled = measurement.median / self._time_scale
            sigfigs = measurement.significant_figures
            spread = measurement.iqr / measurement.median if measurement.has_warnings else None
            cells.append(column.num_to_str(scaled, sigfigs, spread))
        return cells

    @staticmethod
    def color_segment(segment, value, best_value):
        """Wrap ``segment`` in ANSI codes based on ``value`` relative to ``best_value``.

        Within 1% (or 100e-9 absolute, in the medians' units) of the best is
        BEST, within 10% is GOOD, at least 5x is VERY_BAD, at least 2x is BAD;
        anything else is returned unchanged.
        """
        if value <= best_value * 1.01 or value <= best_value + 100e-9:
            prefix = BEST + BOLD
        elif value <= best_value * 1.1:
            prefix = GOOD + BOLD
        elif value >= best_value * 5:
            prefix = VERY_BAD + BOLD
        elif value >= best_value * 2:
            prefix = BAD
        else:
            return segment
        return prefix + segment + TERMINATE * 2

    def row_separator(self, overall_width):
        """Return a thread-count header line for this row, or no lines."""
        if self._num_threads is None:
            return []
        return [f"{self._num_threads} threads: ".ljust(overall_width, "-")]

    def finalize_column_strings(self, column_strings, col_widths):
        """Pad (and optionally colorize) this row's cells to the given widths."""
        if self._colorize == Colorize.ROWWISE:
            row_min = min(r.median for r in self._results if r is not None)
            best_values = [row_min] * len(column_strings)
        elif self._colorize == Colorize.COLUMNWISE:
            best_values = [
                optional_min(r.median for r in column.get_results_for(self._row_group) if r is not None)
                for column in (self._columns or ())
            ]
        else:
            best_values = [-1] * len(column_strings)

        finalized = [column_strings[0].ljust(col_widths[0])]
        for text, width, result, best in zip(column_strings[1:], col_widths[1:], self._results, best_values):
            cell = text.center(width)
            if self._colorize != Colorize.NONE and result is not None and best is not None:
                cell = self.color_segment(cell, result.median, best)
            finalized.append(cell)
        return finalized
150
151
class Table:
    """One formatted table holding every measurement that shares a label.

    Rows are keyed by :meth:`row_fn` (``num_threads``, ``env``, row name) and
    columns by :meth:`col_fn` (``description``). All cells share one time
    unit, picked by ``common.select_unit`` from the fastest median.
    """

    def __init__(
            self,
            results: List[common.Measurement],
            colorize: Colorize,
            trim_significant_figures: bool,
            highlight_warnings: bool
    ):
        # Invariant: Compare._render groups results by label before building a Table.
        assert len({r.label for r in results}) == 1

        self.results = results
        self._colorize = colorize
        self._trim_significant_figures = trim_significant_figures
        self._highlight_warnings = highlight_warnings
        self.label = results[0].label
        self.time_unit, self.time_scale = common.select_unit(
            min(r.median for r in results)
        )

        self.row_keys = common.ordered_unique([self.row_fn(i) for i in results])
        # Group rows by (num_threads, env). The sort is stable, so rows within a
        # group keep their order from ``ordered_unique`` (preserves stmt order).
        self.row_keys.sort(key=self._row_sort_key)
        self.column_keys = common.ordered_unique([self.col_fn(i) for i in results])
        self.rows, self.columns = self.populate_rows_and_columns()

    @staticmethod
    def row_fn(m: common.Measurement) -> Tuple[int, Optional[str], str]:
        """Row key of a measurement: ``(num_threads, env, row name)``."""
        return m.num_threads, m.env, m.as_row_name

    @staticmethod
    def col_fn(m: common.Measurement) -> Optional[str]:
        """Column key of a measurement: its (possibly None) description."""
        return m.description

    @staticmethod
    def _row_sort_key(key: Tuple[int, Optional[str], str]) -> Tuple[int, str]:
        """Sort key on ``(num_threads, env)`` that tolerates a None env."""
        # None and str are not orderable against each other, so a table mixing
        # results with and without an env would otherwise raise TypeError.
        num_threads, env, _ = key
        return num_threads, "" if env is None else env

    def populate_rows_and_columns(self) -> Tuple[Tuple[_Row, ...], Tuple[_Column, ...]]:
        """Build the ``_Row`` and ``_Column`` objects for this table.

        Measurements are placed in a (row key x column key) grid; cells with no
        measurement stay ``None`` (if two measurements share both keys, the
        later one wins). A new row group starts whenever ``num_threads``
        changes, and each column keeps its cells grouped the same way.

        Returns:
            ``(rows, columns)``; every row has the columns registered on it.
        """
        rows: List[_Row] = []
        columns: List[_Column] = []
        ordered_results: List[List[Optional[common.Measurement]]] = [
            [None for _ in self.column_keys]
            for _ in self.row_keys
        ]
        row_position = {key: i for i, key in enumerate(self.row_keys)}
        col_position = {key: i for i, key in enumerate(self.column_keys)}
        for r in self.results:
            i = row_position[self.row_fn(r)]
            j = col_position[self.col_fn(r)]
            ordered_results[i][j] = r

        unique_envs = {r.env for r in self.results}
        render_env = len(unique_envs) > 1
        # Rows render the env as f"({env})", so measure str(env): a None env
        # prints as "None" and has no len() of its own.
        env_str_len = max(len(str(env)) for env in unique_envs) if render_env else 0

        row_name_str_len = max(len(r.as_row_name) for r in self.results)

        prior_num_threads = -1
        prior_env = ""
        row_group = -1
        rows_by_group: List[List[List[Optional[common.Measurement]]]] = []
        for (num_threads, env, _), row in zip(self.row_keys, ordered_results):
            thread_transition = (num_threads != prior_num_threads)
            if thread_transition:
                prior_num_threads = num_threads
                prior_env = ""
                row_group += 1
                rows_by_group.append([])
            rows.append(
                _Row(
                    results=row,
                    row_group=row_group,
                    # Only print the env when it differs from the row above.
                    render_env=(render_env and env != prior_env),
                    env_str_len=env_str_len,
                    row_name_str_len=row_name_str_len,
                    time_scale=self.time_scale,
                    colorize=self._colorize,
                    # Only the first row of a thread group carries the thread count.
                    num_threads=num_threads if thread_transition else None,
                )
            )
            rows_by_group[-1].append(row)
            prior_env = env

        for i in range(len(self.column_keys)):
            grouped_results = [tuple(row[i] for row in g) for g in rows_by_group]
            column = _Column(
                grouped_results=grouped_results,
                time_scale=self.time_scale,
                time_unit=self.time_unit,
                trim_significant_figures=self._trim_significant_figures,
                highlight_warnings=self._highlight_warnings,)
            columns.append(column)

        rows_tuple, columns_tuple = tuple(rows), tuple(columns)
        for ri in rows_tuple:
            ri.register_columns(columns_tuple)
        return rows_tuple, columns_tuple

    def render(self) -> str:
        """Format the whole table, including title and footer, as a string."""
        # Column keys are Optional[str] (see col_fn); a None description becomes
        # a blank header instead of crashing len()/center() below.
        string_rows = [[""] + [key or "" for key in self.column_keys]]
        for r in self.rows:
            string_rows.append(r.as_column_strings())
        # Pad every row to the same number of cells.
        num_cols = max(len(i) for i in string_rows)
        for sr in string_rows:
            sr.extend(["" for _ in range(num_cols - len(sr))])

        col_widths = [max(len(j) for j in i) for i in zip(*string_rows)]
        finalized_columns = ["  |  ".join(i.center(w) for i, w in zip(string_rows[0], col_widths))]
        overall_width = len(finalized_columns[0])
        for string_row, row in zip(string_rows[1:], self.rows):
            finalized_columns.extend(row.row_separator(overall_width))
            finalized_columns.append("  |  ".join(row.finalize_column_strings(string_row, col_widths)))

        newline = "\n"
        has_warnings = self._highlight_warnings and any(ri.has_warnings for ri in self.results)
        return f"""
[{(' ' + (self.label or '') + ' ').center(overall_width - 2, '-')}]
{newline.join(finalized_columns)}

Times are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).
{'(! XX%) Measurement has high variance, where XX is the IQR / median * 100.' + newline if has_warnings else ""}"""[1:]
268
269
class Compare:
    """Helper class for displaying the results of many measurements in a
    formatted table.

    The table format is based on the information fields provided in
    :class:`torch.utils.benchmark.Timer` (`description`, `label`, `sub_label`,
    `num_threads`, etc).

    The table can be directly printed using :meth:`print` or cast to a `str`.

    For a full tutorial on how to use this class, see:
    https://pytorch.org/tutorials/recipes/recipes/benchmark.html

    Args:
        results: List of Measurement to display.
    """
    def __init__(self, results: List[common.Measurement]):
        self._results: List[common.Measurement] = []
        self.extend_results(results)
        self._trim_significant_figures = False
        self._colorize = Colorize.NONE
        self._highlight_warnings = False

    def __str__(self):
        # One rendered table per label, separated by newlines.
        return "\n".join(self._render())

    def extend_results(self, results):
        """Append results to already stored ones.

        All added results must be instances of ``Measurement``. ``results`` may
        be any iterable, including a generator; it is consumed exactly once.

        Raises:
            ValueError: if any element is not a ``Measurement``. In that case
                none of the new results are added.
        """
        # Materialize first: validating and then extending would otherwise
        # exhaust a one-shot iterator and silently store nothing.
        results = list(results)
        for r in results:
            if not isinstance(r, common.Measurement):
                raise ValueError(
                    "Expected an instance of `Measurement`, " f"got {type(r)} instead."
                )
        self._results.extend(results)

    def trim_significant_figures(self):
        """Enables trimming of significant figures when building the formatted table."""
        self._trim_significant_figures = True

    def colorize(self, rowwise=False):
        """Colorize formatted table.

        Colorize columnwise by default.
        """
        self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE

    def highlight_warnings(self):
        """Enables warning highlighting when building formatted table."""
        self._highlight_warnings = True

    def print(self):
        """Print formatted table"""
        print(str(self))

    def _render(self):
        """Merge the stored results and render one table string per label."""
        results = common.Measurement.merge(self._results)
        grouped_results = self._group_by_label(results)
        output = []
        for group in grouped_results.values():
            output.append(self._layout(group))
        return output

    def _group_by_label(self, results: List[common.Measurement]):
        """Group results by ``label``, keeping labels in first-seen order."""
        grouped_results: DefaultDict[str, List[common.Measurement]] = collections.defaultdict(list)
        for r in results:
            grouped_results[r.label].append(r)
        return grouped_results

    def _layout(self, results: List[common.Measurement]):
        """Render a single label's results with the current display options."""
        table = Table(
            results,
            self._colorize,
            self._trim_significant_figures,
            self._highlight_warnings
        )
        return table.render()
349