1"""Base shared classes and utilities.""" 2 3import collections 4import contextlib 5import dataclasses 6import os 7import shutil 8import tempfile 9import textwrap 10import time 11from typing import cast, Any, DefaultDict, Dict, Iterable, Iterator, List, Optional, Tuple 12import uuid 13 14import torch 15 16 17__all__ = ["TaskSpec", "Measurement", "select_unit", "unit_to_english", "trim_sigfig", "ordered_unique", "set_torch_threads"] 18 19 20_MAX_SIGNIFICANT_FIGURES = 4 21_MIN_CONFIDENCE_INTERVAL = 25e-9 # 25 ns 22 23# Measurement will include a warning if the distribution is suspect. All 24# runs are expected to have some variation; these parameters set the 25# thresholds. 26_IQR_WARN_THRESHOLD = 0.1 27_IQR_GROSS_WARN_THRESHOLD = 0.25 28 29 30@dataclasses.dataclass(init=True, repr=False, eq=True, frozen=True) 31class TaskSpec: 32 """Container for information used to define a Timer. (except globals)""" 33 stmt: str 34 setup: str 35 global_setup: str = "" 36 label: Optional[str] = None 37 sub_label: Optional[str] = None 38 description: Optional[str] = None 39 env: Optional[str] = None 40 num_threads: int = 1 41 42 @property 43 def title(self) -> str: 44 """Best effort attempt at a string label for the measurement.""" 45 if self.label is not None: 46 return self.label + (f": {self.sub_label}" if self.sub_label else "") 47 elif "\n" not in self.stmt: 48 return self.stmt + (f": {self.sub_label}" if self.sub_label else "") 49 return ( 50 f"stmt:{f' ({self.sub_label})' if self.sub_label else ''}\n" 51 f"{textwrap.indent(self.stmt, ' ')}" 52 ) 53 54 def setup_str(self) -> str: 55 return ( 56 "" if (self.setup == "pass" or not self.setup) 57 else f"setup:\n{textwrap.indent(self.setup, ' ')}" if "\n" in self.setup 58 else f"setup: {self.setup}" 59 ) 60 61 def summarize(self) -> str: 62 """Build TaskSpec portion of repr string for other containers.""" 63 sections = [ 64 self.title, 65 self.description or "", 66 self.setup_str(), 67 ] 68 return "\n".join([f"{i}\n" if "\n" in i else i for i in sections if i]) 69 70_TASKSPEC_FIELDS = tuple(i.name for i in dataclasses.fields(TaskSpec)) 71 72 73@dataclasses.dataclass(init=True, repr=False) 74class Measurement: 75 """The result of a Timer measurement. 76 77 This class stores one or more measurements of a given statement. It is 78 serializable and provides several convenience methods 79 (including a detailed __repr__) for downstream consumers. 80 """ 81 number_per_run: int 82 raw_times: List[float] 83 task_spec: TaskSpec 84 metadata: Optional[Dict[Any, Any]] = None # Reserved for user payloads. 85 86 def __post_init__(self) -> None: 87 self._sorted_times: Tuple[float, ...] = () 88 self._warnings: Tuple[str, ...] = () 89 self._median: float = -1.0 90 self._mean: float = -1.0 91 self._p25: float = -1.0 92 self._p75: float = -1.0 93 94 def __getattr__(self, name: str) -> Any: 95 # Forward TaskSpec fields for convenience. 

    def __getattr__(self, name: str) -> Any:
        # Forward TaskSpec fields for convenience.
        if name in _TASKSPEC_FIELDS:
            return getattr(self.task_spec, name)
        return super().__getattribute__(name)

    # =========================================================================
    # == Convenience methods for statistics ===================================
    # =========================================================================
    #
    # These methods use raw time divided by number_per_run; this is an
    # extrapolation and hides the fact that different number_per_run will
    # result in different amortization of overheads. However, if Timer has
    # selected an appropriate number_per_run then this is a non-issue, and
    # forcing users to handle that division would result in a poor experience.
    @property
    def times(self) -> List[float]:
        return [t / self.number_per_run for t in self.raw_times]

    @property
    def median(self) -> float:
        self._lazy_init()
        return self._median

    @property
    def mean(self) -> float:
        self._lazy_init()
        return self._mean

    @property
    def iqr(self) -> float:
        self._lazy_init()
        return self._p75 - self._p25

    @property
    def significant_figures(self) -> int:
        """Approximate significant figure estimate.

        This property is intended to give a convenient way to estimate the
        precision of a measurement. It only uses the interquartile region to
        estimate statistics in order to mitigate skew from the tails, and
        uses a static z value of 1.645 since it is not expected to be used
        for small values of `n`, so z can approximate `t`.

        The significant figure estimation is used in conjunction with the
        `trim_sigfig` function to provide a more human interpretable data
        summary. __repr__ does not use this method; it simply displays raw
        values. Significant figure estimation is intended for `Compare`.
        """
        self._lazy_init()
        n_total = len(self._sorted_times)
        lower_bound = int(n_total // 4)
        upper_bound = int(torch.tensor(3 * n_total / 4).ceil())
        interquartile_points: Tuple[float, ...] = self._sorted_times[lower_bound:upper_bound]
        std = torch.tensor(interquartile_points).std(unbiased=False).item()
        sqrt_n = torch.tensor(len(interquartile_points)).sqrt().item()

        # Rough estimates. These are by no means statistically rigorous.
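        # In effect: a rough z-based confidence half-width over the
        # interquartile points,
        #     ci ~= 1.645 * std / sqrt(n),
        # floored at _MIN_CONFIDENCE_INTERVAL, and the significant figure
        # count is how many orders of magnitude the median exceeds that
        # half-width by. (Illustrative numbers: median = 10 us with
        # ci = 0.1 us gives log10(100) = 2 significant figures, before
        # clamping to [1, _MAX_SIGNIFICANT_FIGURES].)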
        confidence_interval = max(1.645 * std / sqrt_n, _MIN_CONFIDENCE_INTERVAL)
        relative_ci = torch.tensor(self._median / confidence_interval).log10().item()
        num_significant_figures = int(torch.tensor(relative_ci).floor())
        return min(max(num_significant_figures, 1), _MAX_SIGNIFICANT_FIGURES)

    @property
    def has_warnings(self) -> bool:
        self._lazy_init()
        return bool(self._warnings)

    def _lazy_init(self) -> None:
        if self.raw_times and not self._sorted_times:
            self._sorted_times = tuple(sorted(self.times))
            _sorted_times = torch.tensor(self._sorted_times, dtype=torch.float64)
            self._median = _sorted_times.quantile(.5).item()
            self._mean = _sorted_times.mean().item()
            self._p25 = _sorted_times.quantile(.25).item()
            self._p75 = _sorted_times.quantile(.75).item()

            def add_warning(msg: str) -> None:
                rel_iqr = self.iqr / self.median * 100
                self._warnings += (
                    f"  WARNING: Interquartile range is {rel_iqr:.1f}% "
                    f"of the median measurement.\n           {msg}",
                )

            if not self.meets_confidence(_IQR_GROSS_WARN_THRESHOLD):
                add_warning("This suggests significant environmental influence.")
            elif not self.meets_confidence(_IQR_WARN_THRESHOLD):
                add_warning("This could indicate system fluctuation.")

    def meets_confidence(self, threshold: float = _IQR_WARN_THRESHOLD) -> bool:
        return self.iqr / self.median < threshold

    @property
    def title(self) -> str:
        return self.task_spec.title

    @property
    def env(self) -> str:
        return (
            "Unspecified env" if self.task_spec.env is None
            else cast(str, self.task_spec.env)
        )

    @property
    def as_row_name(self) -> str:
        return self.sub_label or self.stmt or "[Unknown]"

    def __repr__(self) -> str:
        """
        Example repr:
            <utils.common.Measurement object at 0x7f395b6ac110>
            Broadcasting add (4x8)
              Median: 5.73 us
              IQR:    2.25 us (4.01 to 6.26)
              372 measurements, 100 runs per measurement, 1 thread
              WARNING: Interquartile range is 39.4% of the median measurement.
                       This suggests significant environmental influence.
        """
        self._lazy_init()
        skip_line, newline = "MEASUREMENT_REPR_SKIP_LINE", "\n"
        n = len(self._sorted_times)
        time_unit, time_scale = select_unit(self._median)
        iqr_filter = '' if n >= 4 else skip_line

        repr_str = f"""
{super().__repr__()}
{self.task_spec.summarize()}
  {'Median: ' if n > 1 else ''}{self._median / time_scale:.2f} {time_unit}
  {iqr_filter}IQR:    {self.iqr / time_scale:.2f} {time_unit} ({self._p25 / time_scale:.2f} to {self._p75 / time_scale:.2f})
  {n} measurement{'s' if n > 1 else ''}, {self.number_per_run} runs {'per measurement,' if n > 1 else ','} {self.num_threads} thread{'s' if self.num_threads > 1 else ''}
{newline.join(self._warnings)}""".strip()  # noqa: B950

        return "\n".join(l for l in repr_str.splitlines(keepends=False) if skip_line not in l)
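
    # Illustrative sketch of merging replicates with `merge` (defined below).
    # The numbers are made up, and `spec` stands for any shared TaskSpec
    # (TaskSpec is frozen, hence hashable, so replicates group by their spec):
    #
    #   m1 = Measurement(number_per_run=100, raw_times=[1.00e-2], task_spec=spec)
    #   m2 = Measurement(number_per_run=200, raw_times=[2.10e-2], task_spec=spec)
    #   [combined] = Measurement.merge([m1, m2])
    #   combined.times  # -> [1.0e-4, 1.05e-4], extrapolated to number_per_run=1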

    @staticmethod
    def merge(measurements: Iterable["Measurement"]) -> List["Measurement"]:
        """Convenience method for merging replicates.

        Merge will extrapolate times to `number_per_run=1` and will not
        transfer any metadata. (Since it might differ between replicates.)
        """
        grouped_measurements: DefaultDict[TaskSpec, List[Measurement]] = collections.defaultdict(list)
        for m in measurements:
            grouped_measurements[m.task_spec].append(m)

        def merge_group(task_spec: TaskSpec, group: List["Measurement"]) -> "Measurement":
            times: List[float] = []
            for m in group:
                # Different measurements could have different `number_per_run`,
                # so we call `.times` which normalizes the results.
                times.extend(m.times)

            return Measurement(
                number_per_run=1,
                raw_times=times,
                task_spec=task_spec,
                metadata=None,
            )

        return [merge_group(t, g) for t, g in grouped_measurements.items()]


def select_unit(t: float) -> Tuple[str, float]:
    """Determine how to scale times for O(1) magnitude.

    This utility is used to format numbers for human consumption.
    """
    time_unit = {-3: "ns", -2: "us", -1: "ms"}.get(int(torch.tensor(t).log10().item() // 3), "s")
    time_scale = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}[time_unit]
    return time_unit, time_scale


def unit_to_english(u: str) -> str:
    return {
        "ns": "nanosecond",
        "us": "microsecond",
        "ms": "millisecond",
        "s": "second",
    }[u]


def trim_sigfig(x: float, n: int) -> float:
    """Trim `x` to `n` significant figures. (e.g. 3.14159, 2 -> 3.10000)"""
    assert n == int(n)
    magnitude = int(torch.tensor(x).abs().log10().ceil().item())
    scale = 10 ** (magnitude - n)
    return float(torch.tensor(x / scale).round() * scale)


def ordered_unique(elements: Iterable[Any]) -> List[Any]:
    return list(collections.OrderedDict(dict.fromkeys(elements)).keys())


@contextlib.contextmanager
def set_torch_threads(n: int) -> Iterator[None]:
    prior_num_threads = torch.get_num_threads()
    try:
        torch.set_num_threads(n)
        yield
    finally:
        torch.set_num_threads(prior_num_threads)


def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> str:
    """Create a temporary directory. The caller is responsible for cleanup.

    This function is conceptually similar to `tempfile.mkdtemp`, but with
    the key additional feature that it will use shared memory if the
    `BENCHMARK_USE_DEV_SHM` environment variable is set. This is an
    implementation detail, but an important one for cases where many Callgrind
    measurements are collected at once. (Such as when collecting
    microbenchmarks.)

    This is an internal utility, and is exported solely so that
    microbenchmarks can reuse it.
    """
    use_dev_shm: bool = (os.getenv("BENCHMARK_USE_DEV_SHM") or "").lower() in ("1", "true")
    if use_dev_shm:
        root = "/dev/shm/pytorch_benchmark_utils"
        assert os.name == "posix", f"tmpfs (/dev/shm) is POSIX only, current platform is {os.name}"
        assert os.path.exists("/dev/shm"), "This system does not appear to support tmpfs (/dev/shm)."
        os.makedirs(root, exist_ok=True)

        # Because we're working in shared memory, it is more important than
        # usual to clean up ALL intermediate files. However we don't want every
        # worker to walk over all outstanding directories, so instead we only
        # check when we are sure that it won't lead to contention.
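        # To make the ownership scheme explicit: every directory this function
        # creates under `root` later receives an `owner.pid` file recording the
        # creating process (written near the end of this function). A sibling
        # directory is treated as orphaned when that recorded PID no longer
        # refers to a live process, which is what the `os.kill(pid, 0)` probe
        # below checks.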
        if gc_dev_shm:
            for i in os.listdir(root):
                owner_file = os.path.join(root, i, "owner.pid")
                if not os.path.exists(owner_file):
                    continue

                with open(owner_file) as f:
                    owner_pid = int(f.read())

                if owner_pid == os.getpid():
                    continue

                try:
                    # https://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid-in-python
                    os.kill(owner_pid, 0)

                except OSError:
                    print(f"Detected that {os.path.join(root, i)} was orphaned in shared memory. Cleaning up.")
                    shutil.rmtree(os.path.join(root, i))

    else:
        root = tempfile.gettempdir()

    # We include the time so names sort by creation time, and add a UUID
    # to ensure we don't collide.
    name = f"{prefix or tempfile.gettempprefix()}__{int(time.time())}__{uuid.uuid4()}"
    path = os.path.join(root, name)
    os.makedirs(path, exist_ok=False)

    if use_dev_shm:
        with open(os.path.join(path, "owner.pid"), "w") as f:
            f.write(str(os.getpid()))

    return path
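

# Illustrative usage of the helpers above (a sketch; values and the prefix are
# examples, and floating point results are approximate):
#
#   unit, scale = select_unit(2.5e-6)   # -> ("us", 1e-6)
#   f"{2.5e-6 / scale:.2f} {unit}"      # -> "2.50 us"
#   trim_sigfig(3.14159, 2)             # -> ~3.1
#
#   with set_torch_threads(1):
#       ...  # time statements with a single intra-op thread
#
#   path = _make_temp_dir(prefix="callgrind_scratch")
#   try:
#       ...  # write intermediate artifacts under `path`
#   finally:
#       shutil.rmtree(path)  # cleanup is the caller's responsibility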