# mypy: allow-untyped-defs
"""Display class to aggregate and print the results of many measurements."""
import collections
import enum
import itertools as it
from typing import DefaultDict, List, Optional, Tuple

from torch.utils.benchmark.utils import common
from torch import tensor as _tensor
import operator

__all__ = ["Colorize", "Compare"]

# ANSI terminal escape codes used to color table cells (see _Row.color_segment).
BEST = "\033[92m"
GOOD = "\033[34m"
BAD = "\033[2m\033[91m"
VERY_BAD = "\033[31m"
BOLD = "\033[1m"
TERMINATE = "\033[0m"


class Colorize(enum.Enum):
    """Cell-coloring mode for rendered tables.

    NONE leaves cells uncolored. ROWWISE compares each cell with the smallest
    median in its row. COLUMNWISE compares each cell with the smallest median
    in the same column within the same thread-count group.
    """

    NONE = "none"
    COLUMNWISE = "columnwise"
    ROWWISE = "rowwise"


# Classes to separate internal bookkeeping from what is rendered.
class _Column:
    """Formatting state shared by every cell of one table column.

    All of the column's measurements are inspected up front so that every
    cell is rendered with the same width and number of decimal places.
    """

    def __init__(
        self,
        grouped_results: List[Tuple[Optional[common.Measurement], ...]],
        time_scale: float,
        time_unit: str,
        trim_significant_figures: bool,
        highlight_warnings: bool,
    ):
        """
        Args:
            grouped_results: One tuple per thread-count group; each tuple holds
                this column's cell for every row of the group, or ``None``
                where a row has no measurement for this column.
            time_scale: Divisor that converts a median into the display unit.
            time_unit: Display unit name (stored only; not used for formatting).
            trim_significant_figures: If True, the number of decimal places is
                derived from each measurement's ``significant_figures`` and
                values are trimmed in ``num_to_str``; otherwise one decimal
                place is always used.
            highlight_warnings: If True and any measurement in the column has
                warnings, reserve room for a ``(! XX%)`` suffix on each cell.
        """
        self._grouped_results = grouped_results
        self._flat_results = list(it.chain(*grouped_results))
        self._time_scale = time_scale
        self._time_unit = time_unit
        self._trim_significant_figures = trim_significant_figures
        self._highlight_warnings = (
            highlight_warnings
            and any(r.has_warnings for r in self._flat_results if r)
        )
        # Digits left of the decimal point for each cell (None for empty cells).
        leading_digits = [
            int(_tensor(r.median / self._time_scale).log10().ceil()) if r else None
            for r in self._flat_results
        ]
        # Widest integer part in the column.
        unit_digits = max(d for d in leading_digits if d is not None)
        # Smallest number of decimal places needed by any cell to show its
        # significant figures (or a fixed single decimal when not trimming).
        decimal_digits = min(
            max(m.significant_figures - digits, 0)
            for digits, m in zip(leading_digits, self._flat_results)
            if (m is not None) and (digits is not None)
        ) if self._trim_significant_figures else 1
        # +1 for the decimal point when a fractional part is printed.
        length = unit_digits + decimal_digits + (1 if decimal_digits else 0)
        # Right-aligned number field followed by the warning-suffix field
        # (7 characters wide only when warnings are highlighted).
        self._template = f"{{:>{length}.{decimal_digits}f}}{{:>{7 if self._highlight_warnings else 0}}}"

    def get_results_for(self, group):
        """Return this column's cells for thread-count group ``group``."""
        return self._grouped_results[group]

    def num_to_str(self, value: Optional[float], estimated_sigfigs: int, spread: Optional[float]):
        """Format one cell of this column.

        A ``value`` of None yields blanks exactly as wide as a rendered number,
        so empty cells keep the column aligned. ``spread`` is a fraction (the
        caller passes IQR / median) shown as ``(! XX%)`` only when this column
        highlights warnings.
        """
        if value is None:
            return " " * len(self.num_to_str(1, estimated_sigfigs, None))

        if self._trim_significant_figures:
            value = common.trim_sigfig(value, estimated_sigfigs)

        return self._template.format(
            value,
            f" (! {spread * 100:.0f}%)" if self._highlight_warnings and spread is not None else "")


def optional_min(seq):
    """Return the minimum of ``seq``, or ``None`` if it is empty."""
    values = list(seq)
    return min(values) if values else None


class _Row:
    """One rendered table row: a row label plus one cell per column."""

    def __init__(self, results, row_group, render_env, env_str_len,
                 row_name_str_len, time_scale, colorize, num_threads=None):
        """
        Args:
            results: One entry per column; ``None`` where the row has no
                measurement for that column.
            row_group: Index of the thread-count group this row belongs to.
            render_env: Whether to prefix the row label with ``(env)``.
            env_str_len: Width reserved for the env text.
            row_name_str_len: Longest row name in the table (stored only).
            time_scale: Divisor that converts a median into the display unit.
            colorize: A ``Colorize`` mode.
            num_threads: Set only on the first row of a thread-count group; a
                ``"<n> threads: ---"`` separator line is then emitted before it.
        """
        super().__init__()
        self._results = results
        self._row_group = row_group
        self._render_env = render_env
        self._env_str_len = env_str_len
        self._row_name_str_len = row_name_str_len
        self._time_scale = time_scale
        self._colorize = colorize
        self._columns: Tuple[_Column, ...] = ()
        self._num_threads = num_threads

    def register_columns(self, columns: Tuple[_Column, ...]):
        """Attach the table's columns, which carry the per-column formatting."""
        self._columns = columns

    def as_column_strings(self):
        """Return the row label followed by one unpadded string per column."""
        concrete_results = [r for r in self._results if r is not None]
        env = f"({concrete_results[0].env})" if self._render_env else ""
        # +4 covers the parentheses and two spaces of separation.
        env = env.ljust(self._env_str_len + 4)
        output = ["  " + env + concrete_results[0].as_row_name]
        for m, col in zip(self._results, self._columns or ()):
            if m is None:
                output.append(col.num_to_str(None, 1, None))
            else:
                output.append(col.num_to_str(
                    m.median / self._time_scale,
                    m.significant_figures,
                    m.iqr / m.median if m.has_warnings else None
                ))
        return output

    @staticmethod
    def color_segment(segment, value, best_value):
        """Wrap ``segment`` in ANSI colors according to ``value / best_value``.

        Within 1% (or 100e-9 absolute) of the best: BEST; within 10%: GOOD;
        at least 5x: VERY_BAD; at least 2x: BAD; otherwise uncolored.
        """
        if value <= best_value * 1.01 or value <= best_value + 100e-9:
            return BEST + BOLD + segment + TERMINATE * 2
        if value <= best_value * 1.1:
            return GOOD + BOLD + segment + TERMINATE * 2
        if value >= best_value * 5:
            return VERY_BAD + BOLD + segment + TERMINATE * 2
        if value >= best_value * 2:
            return BAD + segment + TERMINATE * 2

        return segment

    def row_separator(self, overall_width):
        """Return the thread-count separator line (if any) that precedes this row."""
        return (
            [f"{self._num_threads} threads: ".ljust(overall_width, "-")]
            if self._num_threads is not None else []
        )

    def finalize_column_strings(self, column_strings, col_widths):
        """Pad (and optionally colorize) this row's strings to ``col_widths``.

        The label is left-justified; each cell is centered and, unless coloring
        is disabled, colored against the best value for the active mode.
        """
        best_values = [-1 for _ in column_strings]
        if self._colorize == Colorize.ROWWISE:
            row_min = min(r.median for r in self._results if r is not None)
            best_values = [row_min for _ in column_strings]
        elif self._colorize == Colorize.COLUMNWISE:
            best_values = [
                optional_min(r.median for r in column.get_results_for(self._row_group) if r is not None)
                for column in (self._columns or ())
            ]

        row_contents = [column_strings[0].ljust(col_widths[0])]
        for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values):
            col_str = col_str.center(width)
            if self._colorize != Colorize.NONE and result is not None and best_value is not None:
                col_str = self.color_segment(col_str, result.median, best_value)
            row_contents.append(col_str)
        return row_contents


class Table:
    """Grid layout for all measurements that share a single label."""

    def __init__(
        self,
        results: List[common.Measurement],
        colorize: Colorize,
        trim_significant_figures: bool,
        highlight_warnings: bool
    ):
        """
        Args:
            results: Measurements to lay out; all must have the same ``label``.
            colorize: Cell-coloring mode.
            trim_significant_figures: Forwarded to every ``_Column``.
            highlight_warnings: Forwarded to every ``_Column`` and controls the
                variance footnote.
        """
        assert len({r.label for r in results}) == 1

        self.results = results
        self._colorize = colorize
        self._trim_significant_figures = trim_significant_figures
        self._highlight_warnings = highlight_warnings
        self.label = results[0].label
        self.time_unit, self.time_scale = common.select_unit(
            min(r.median for r in results)
        )

        self.row_keys = common.ordered_unique([self.row_fn(i) for i in results])
        # Group rows by thread count, then env. The sort is stable, so the
        # original statement order is preserved within each group. A None env
        # sorts before any string env instead of raising TypeError.
        self.row_keys.sort(key=lambda key: (key[0], key[1] is not None, key[1] or ""))  # preserve stmt order
        self.column_keys = common.ordered_unique([self.col_fn(i) for i in results])
        self.rows, self.columns = self.populate_rows_and_columns()

    @staticmethod
    def row_fn(m: common.Measurement) -> Tuple[int, Optional[str], str]:
        """Row key of a measurement: ``(num_threads, env, row name)``."""
        return m.num_threads, m.env, m.as_row_name

    @staticmethod
    def col_fn(m: common.Measurement) -> Optional[str]:
        """Column key of a measurement: its ``description``."""
        return m.description

    def populate_rows_and_columns(self) -> Tuple[Tuple[_Row, ...], Tuple[_Column, ...]]:
        """Build the table's rows and columns.

        Returns:
            ``(rows, columns)``; every row has the full columns tuple registered.
        """
        rows: List[_Row] = []
        columns: List[_Column] = []

        # Dense grid: ordered_results[row][col] is a measurement or None.
        ordered_results: List[List[Optional[common.Measurement]]] = [
            [None for _ in self.column_keys]
            for _ in self.row_keys
        ]
        row_position = {key: i for i, key in enumerate(self.row_keys)}
        col_position = {key: i for i, key in enumerate(self.column_keys)}
        for r in self.results:
            i = row_position[self.row_fn(r)]
            j = col_position[self.col_fn(r)]
            ordered_results[i][j] = r

        # The env prefix is only shown when the table mixes several envs.
        unique_envs = {r.env for r in self.results}
        render_env = len(unique_envs) > 1
        # str() matches how the env is printed (f"({env})"), so a None env is
        # measured as "None" rather than raising TypeError.
        env_str_len = max(len(str(i)) for i in unique_envs) if render_env else 0

        row_name_str_len = max(len(r.as_row_name) for r in self.results)

        prior_num_threads = -1
        prior_env = ""
        row_group = -1
        rows_by_group: List[List[List[Optional[common.Measurement]]]] = []
        for (num_threads, env, _), row in zip(self.row_keys, ordered_results):
            # A change in thread count starts a new group (and separator line).
            thread_transition = (num_threads != prior_num_threads)
            if thread_transition:
                prior_num_threads = num_threads
                prior_env = ""
                row_group += 1
                rows_by_group.append([])
            rows.append(
                _Row(
                    results=row,
                    row_group=row_group,
                    # Print the env only on the first row of each run of an env.
                    render_env=(render_env and env != prior_env),
                    env_str_len=env_str_len,
                    row_name_str_len=row_name_str_len,
                    time_scale=self.time_scale,
                    colorize=self._colorize,
                    num_threads=num_threads if thread_transition else None,
                )
            )
            rows_by_group[-1].append(row)
            prior_env = env

        for i in range(len(self.column_keys)):
            grouped_results = [tuple(row[i] for row in g) for g in rows_by_group]
            column = _Column(
                grouped_results=grouped_results,
                time_scale=self.time_scale,
                time_unit=self.time_unit,
                trim_significant_figures=self._trim_significant_figures,
                highlight_warnings=self._highlight_warnings,)
            columns.append(column)

        rows_tuple, columns_tuple = tuple(rows), tuple(columns)
        for ri in rows_tuple:
            ri.register_columns(columns_tuple)
        return rows_tuple, columns_tuple

    def render(self) -> str:
        """Render the table (title bar, header, rows, unit footer) as a string."""
        # Column keys are Optional[str]; a missing description gets a blank header.
        string_rows = [[""] + ["" if key is None else key for key in self.column_keys]]
        for r in self.rows:
            string_rows.append(r.as_column_strings())
        num_cols = max(len(i) for i in string_rows)
        for sr in string_rows:
            sr.extend(["" for _ in range(num_cols - len(sr))])

        col_widths = [max(len(j) for j in i) for i in zip(*string_rows)]
        finalized_columns = ["  |  ".join(i.center(w) for i, w in zip(string_rows[0], col_widths))]
        overall_width = len(finalized_columns[0])
        for string_row, row in zip(string_rows[1:], self.rows):
            finalized_columns.extend(row.row_separator(overall_width))
            finalized_columns.append("  |  ".join(row.finalize_column_strings(string_row, col_widths)))

        newline = "\n"
        has_warnings = self._highlight_warnings and any(ri.has_warnings for ri in self.results)
        return f"""
[{(' ' + (self.label or '') + ' ').center(overall_width - 2, '-')}]
{newline.join(finalized_columns)}

Times are in {common.unit_to_english(self.time_unit)}s ({self.time_unit}).
{'(! XX%) Measurement has high variance, where XX is the IQR / median * 100.' + newline if has_warnings else ""}"""[1:]


class Compare:
    """Helper class for displaying the results of many measurements in a
    formatted table.

    The table format is based on the information fields provided in
    :class:`torch.utils.benchmark.Timer` (`description`, `label`, `sub_label`,
    `num_threads`, etc).

    The table can be directly printed using :meth:`print` or cast as a `str`.

    For a full tutorial on how to use this class, see:
    https://pytorch.org/tutorials/recipes/recipes/benchmark.html

    Args:
        results: List of Measurement to display.
    """
    def __init__(self, results: List[common.Measurement]):
        self._results: List[common.Measurement] = []
        self.extend_results(results)
        self._trim_significant_figures = False
        self._colorize = Colorize.NONE
        self._highlight_warnings = False

    def __str__(self):
        # One rendered table per distinct label, separated by newlines.
        return "\n".join(self._render())

    def extend_results(self, results):
        """Append results to already stored ones.

        All added results must be instances of ``Measurement``; if any is not,
        ``ValueError`` is raised and nothing is appended.
        """
        # Materialize first: a one-shot iterator would otherwise be exhausted
        # by the validation loop and nothing would be appended.
        results = list(results)
        for r in results:
            if not isinstance(r, common.Measurement):
                raise ValueError(
                    "Expected an instance of `Measurement`, " f"got {type(r)} instead."
                )
        self._results.extend(results)

    def trim_significant_figures(self):
        """Enables trimming of significant figures when building the formatted table."""
        self._trim_significant_figures = True

    def colorize(self, rowwise=False):
        """Colorize formatted table.

        Colorize columnwise by default.

        Args:
            rowwise: If True, compare each cell with the best value in its row
                instead of its column.
        """
        self._colorize = Colorize.ROWWISE if rowwise else Colorize.COLUMNWISE

    def highlight_warnings(self):
        """Enables warning highlighting when building formatted table."""
        self._highlight_warnings = True

    def print(self):
        """Print formatted table"""
        print(str(self))

    def _render(self):
        """Merge stored results and render one table string per label."""
        results = common.Measurement.merge(self._results)
        grouped_results = self._group_by_label(results)
        output = []
        for group in grouped_results.values():
            output.append(self._layout(group))
        return output

    def _group_by_label(self, results: List[common.Measurement]):
        """Bucket results by ``label``, keeping first-seen label order."""
        grouped_results: DefaultDict[str, List[common.Measurement]] = collections.defaultdict(list)
        for r in results:
            grouped_results[r.label].append(r)
        return grouped_results

    def _layout(self, results: List[common.Measurement]):
        """Render a single same-label group with the current display options."""
        table = Table(
            results,
            self._colorize,
            self._trim_significant_figures,
            self._highlight_warnings
        )
        return table.render()