1"""Base shared classes and utilities."""
2
3import collections
4import contextlib
5import dataclasses
6import os
7import shutil
8import tempfile
9import textwrap
10import time
11from typing import cast, Any, DefaultDict, Dict, Iterable, Iterator, List, Optional, Tuple
12import uuid
13
14import torch
15
16
17__all__ = ["TaskSpec", "Measurement", "select_unit", "unit_to_english", "trim_sigfig", "ordered_unique", "set_torch_threads"]
18
19
20_MAX_SIGNIFICANT_FIGURES = 4
21_MIN_CONFIDENCE_INTERVAL = 25e-9  # 25 ns
22
23# Measurement will include a warning if the distribution is suspect. All
24# runs are expected to have some variation; these parameters set the
25# thresholds.
26_IQR_WARN_THRESHOLD = 0.1
27_IQR_GROSS_WARN_THRESHOLD = 0.25
28
29
30@dataclasses.dataclass(init=True, repr=False, eq=True, frozen=True)
31class TaskSpec:
32    """Container for information used to define a Timer. (except globals)"""
    stmt: str
    setup: str
    global_setup: str = ""
    label: Optional[str] = None
    sub_label: Optional[str] = None
    description: Optional[str] = None
    env: Optional[str] = None
    num_threads: int = 1

    @property
    def title(self) -> str:
        """Best effort attempt at a string label for the measurement."""
        if self.label is not None:
            return self.label + (f": {self.sub_label}" if self.sub_label else "")
        elif "\n" not in self.stmt:
            return self.stmt + (f": {self.sub_label}" if self.sub_label else "")
        return (
            f"stmt:{f' ({self.sub_label})' if self.sub_label else ''}\n"
            f"{textwrap.indent(self.stmt, '  ')}"
        )

    def setup_str(self) -> str:
        return (
            "" if (self.setup == "pass" or not self.setup)
            else f"setup:\n{textwrap.indent(self.setup, '  ')}" if "\n" in self.setup
            else f"setup: {self.setup}"
        )

    def summarize(self) -> str:
        """Build TaskSpec portion of repr string for other containers."""
        sections = [
            self.title,
            self.description or "",
            self.setup_str(),
        ]
        return "\n".join([f"{i}\n" if "\n" in i else i for i in sections if i])
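
    # Worked example (arbitrary, hypothetical values):
    #
    #   spec = TaskSpec(stmt="x + y", setup="x = 1; y = 2", label="add",
    #                   sub_label="scalars", description="toy case")
    #
    # `spec.title` is "add: scalars" and `spec.summarize()` returns:
    #
    #   add: scalars
    #   toy case
    #   setup: x = 1; y = 2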

_TASKSPEC_FIELDS = tuple(i.name for i in dataclasses.fields(TaskSpec))


@dataclasses.dataclass(init=True, repr=False)
class Measurement:
    """The result of a Timer measurement.

    This class stores one or more measurements of a given statement. It is
    serializable and provides several convenience methods
    (including a detailed __repr__) for downstream consumers.
    """
    number_per_run: int
    raw_times: List[float]
    task_spec: TaskSpec
    metadata: Optional[Dict[Any, Any]] = None  # Reserved for user payloads.

    def __post_init__(self) -> None:
        self._sorted_times: Tuple[float, ...] = ()
        self._warnings: Tuple[str, ...] = ()
        self._median: float = -1.0
        self._mean: float = -1.0
        self._p25: float = -1.0
        self._p75: float = -1.0

    def __getattr__(self, name: str) -> Any:
        # Forward TaskSpec fields for convenience.
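        # For example, `measurement.num_threads` resolves to
        # `measurement.task_spec.num_threads`.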
        if name in _TASKSPEC_FIELDS:
            return getattr(self.task_spec, name)
        return super().__getattribute__(name)

    # =========================================================================
    # == Convenience methods for statistics ===================================
    # =========================================================================
    #
    # These methods use raw time divided by number_per_run; this is an
    # extrapolation and hides the fact that different values of number_per_run
    # will result in different amortization of overheads. However, if Timer has
    # selected an appropriate number_per_run then this is a non-issue, and
    # forcing users to handle that division would result in a poor experience.
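    #
    # For example (arbitrary numbers): raw_times=[0.4, 0.5] with
    # number_per_run=100 yields times == [0.004, 0.005], i.e. the time per
    # single invocation of `stmt`.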
    @property
    def times(self) -> List[float]:
        return [t / self.number_per_run for t in self.raw_times]

    @property
    def median(self) -> float:
        self._lazy_init()
        return self._median

    @property
    def mean(self) -> float:
        self._lazy_init()
        return self._mean

    @property
    def iqr(self) -> float:
        self._lazy_init()
        return self._p75 - self._p25

    @property
    def significant_figures(self) -> int:
130        """Approximate significant figure estimate.
131
132        This property is intended to give a convenient way to estimate the
133        precision of a measurement. It only uses the interquartile region to
134        estimate statistics to try to mitigate skew from the tails, and
135        uses a static z value of 1.645 since it is not expected to be used
136        for small values of `n`, so z can approximate `t`.
137
138        The significant figure estimation used in conjunction with the
139        `trim_sigfig` method to provide a more human interpretable data
140        summary. __repr__ does not use this method; it simply displays raw
141        values. Significant figure estimation is intended for `Compare`.
142        """
        self._lazy_init()
        n_total = len(self._sorted_times)
        lower_bound = int(n_total // 4)
        upper_bound = int(torch.tensor(3 * n_total / 4).ceil())
        interquartile_points: Tuple[float, ...] = self._sorted_times[lower_bound:upper_bound]
        std = torch.tensor(interquartile_points).std(unbiased=False).item()
        sqrt_n = torch.tensor(len(interquartile_points)).sqrt().item()

        # Rough estimates. These are by no means statistically rigorous.
        confidence_interval = max(1.645 * std / sqrt_n, _MIN_CONFIDENCE_INTERVAL)
        relative_ci = torch.tensor(self._median / confidence_interval).log10().item()
        num_significant_figures = int(torch.tensor(relative_ci).floor())
        return min(max(num_significant_figures, 1), _MAX_SIGNIFICANT_FIGURES)

    @property
    def has_warnings(self) -> bool:
        self._lazy_init()
        return bool(self._warnings)

    def _lazy_init(self) -> None:
        if self.raw_times and not self._sorted_times:
            self._sorted_times = tuple(sorted(self.times))
            _sorted_times = torch.tensor(self._sorted_times, dtype=torch.float64)
            self._median = _sorted_times.quantile(.5).item()
            self._mean = _sorted_times.mean().item()
            self._p25 = _sorted_times.quantile(.25).item()
            self._p75 = _sorted_times.quantile(.75).item()

            def add_warning(msg: str) -> None:
                rel_iqr = self.iqr / self.median * 100
                self._warnings += (
                    f"  WARNING: Interquartile range is {rel_iqr:.1f}% "
                    f"of the median measurement.\n           {msg}",
                )

            if not self.meets_confidence(_IQR_GROSS_WARN_THRESHOLD):
                add_warning("This suggests significant environmental influence.")
            elif not self.meets_confidence(_IQR_WARN_THRESHOLD):
                add_warning("This could indicate system fluctuation.")

    def meets_confidence(self, threshold: float = _IQR_WARN_THRESHOLD) -> bool:
        return self.iqr / self.median < threshold

    @property
    def title(self) -> str:
        return self.task_spec.title

    @property
    def env(self) -> str:
        return (
194            "Unspecified env" if self.taskspec.env is None
195            else cast(str, self.taskspec.env)
        )

    @property
    def as_row_name(self) -> str:
        return self.sub_label or self.stmt or "[Unknown]"

    def __repr__(self) -> str:
        """
        Example repr:
            <utils.common.Measurement object at 0x7f395b6ac110>
              Broadcasting add (4x8)
              Median: 5.73 us
              IQR:    2.25 us (4.01 to 6.26)
              372 measurements, 100 runs per measurement, 1 thread
              WARNING: Interquartile range is 39.4% of the median measurement.
                       This suggests significant environmental influence.
        """
        self._lazy_init()
        skip_line, newline = "MEASUREMENT_REPR_SKIP_LINE", "\n"
        n = len(self._sorted_times)
        time_unit, time_scale = select_unit(self._median)
        iqr_filter = '' if n >= 4 else skip_line

        repr_str = f"""
{super().__repr__()}
{self.task_spec.summarize()}
  {'Median: ' if n > 1 else ''}{self._median / time_scale:.2f} {time_unit}
  {iqr_filter}IQR:    {self.iqr / time_scale:.2f} {time_unit} ({self._p25 / time_scale:.2f} to {self._p75 / time_scale:.2f})
  {n} measurement{'s' if n > 1 else ''}, {self.number_per_run} runs {'per measurement,' if n > 1 else ','} {self.num_threads} thread{'s' if self.num_threads > 1 else ''}
{newline.join(self._warnings)}""".strip()  # noqa: B950

        return "\n".join(l for l in repr_str.splitlines(keepends=False) if skip_line not in l)

    @staticmethod
    def merge(measurements: Iterable["Measurement"]) -> List["Measurement"]:
231        """Convenience method for merging replicates.
232
233        Merge will extrapolate times to `number_per_run=1` and will not
234        transfer any metadata. (Since it might differ between replicates)
235        """
        grouped_measurements: DefaultDict[TaskSpec, List[Measurement]] = collections.defaultdict(list)
        for m in measurements:
            grouped_measurements[m.task_spec].append(m)

        def merge_group(task_spec: TaskSpec, group: List["Measurement"]) -> "Measurement":
            times: List[float] = []
            for m in group:
                # Different measurements could have different `number_per_run`,
                # so we call `.times` which normalizes the results.
                times.extend(m.times)

            return Measurement(
                number_per_run=1,
                raw_times=times,
                task_spec=task_spec,
                metadata=None,
            )

        return [merge_group(t, g) for t, g in grouped_measurements.items()]


def select_unit(t: float) -> Tuple[str, float]:
    """Determine how to scale times for O(1) magnitude.

    This utility is used to format numbers for human consumption.
    """
    time_unit = {-3: "ns", -2: "us", -1: "ms"}.get(int(torch.tensor(t).log10().item() // 3), "s")
    time_scale = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}[time_unit]
    return time_unit, time_scale


def unit_to_english(u: str) -> str:
    return {
        "ns": "nanosecond",
        "us": "microsecond",
        "ms": "millisecond",
        "s": "second",
    }[u]


def trim_sigfig(x: float, n: int) -> float:
    """Trim `x` to `n` significant figures. (e.g. 3.14159, 2 -> 3.10000)"""
    assert n == int(n)
    magnitude = int(torch.tensor(x).abs().log10().ceil().item())
    scale = 10 ** (magnitude - n)
    return float(torch.tensor(x / scale).round() * scale)


def ordered_unique(elements: Iterable[Any]) -> List[Any]:
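    # Deduplicate while preserving first-seen order,
    # e.g. ordered_unique([3, 1, 3, 2]) -> [3, 1, 2].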
    return list(collections.OrderedDict(dict.fromkeys(elements)).keys())


@contextlib.contextmanager
def set_torch_threads(n: int) -> Iterator[None]:
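    # Temporarily pins the intra-op thread count, e.g.:
    #
    #   with set_torch_threads(1):
    #       ...  # run single-threaded work here
    #
    # The prior setting is restored on exit, even if an exception is raised.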
    prior_num_threads = torch.get_num_threads()
    try:
        torch.set_num_threads(n)
        yield
    finally:
        torch.set_num_threads(prior_num_threads)


def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> str:
299    """Create a temporary directory. The caller is responsible for cleanup.
300
301    This function is conceptually similar to `tempfile.mkdtemp`, but with
302    the key additional feature that it will use shared memory if the
303    `BENCHMARK_USE_DEV_SHM` environment variable is set. This is an
304    implementation detail, but an important one for cases where many Callgrind
305    measurements are collected at once. (Such as when collecting
306    microbenchmarks.)
307
308    This is an internal utility, and is exported solely so that microbenchmarks
309    can reuse the util.
310    """
    use_dev_shm: bool = (os.getenv("BENCHMARK_USE_DEV_SHM") or "").lower() in ("1", "true")
    if use_dev_shm:
        root = "/dev/shm/pytorch_benchmark_utils"
        assert os.name == "posix", f"tmpfs (/dev/shm) is POSIX only, current platform is {os.name}"
        assert os.path.exists("/dev/shm"), "This system does not appear to support tmpfs (/dev/shm)."
        os.makedirs(root, exist_ok=True)

        # Because we're working in shared memory, it is more important than
        # usual to clean up ALL intermediate files. However, we don't want
        # every worker to walk over all outstanding directories, so instead we
        # only check when we are sure that it won't lead to contention.
        if gc_dev_shm:
            for i in os.listdir(root):
                owner_file = os.path.join(root, i, "owner.pid")
                if not os.path.exists(owner_file):
                    continue

                with open(owner_file) as f:
                    owner_pid = int(f.read())

                if owner_pid == os.getpid():
                    continue

                try:
                    # https://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid-in-python
                    os.kill(owner_pid, 0)

                except OSError:
                    print(f"Detected that {os.path.join(root, i)} was orphaned in shared memory. Cleaning up.")
                    shutil.rmtree(os.path.join(root, i))

    else:
        root = tempfile.gettempdir()

    # We include the time so names sort by creation time, and add a UUID
    # to ensure we don't collide.
    name = f"{prefix or tempfile.gettempprefix()}__{int(time.time())}__{uuid.uuid4()}"
    path = os.path.join(root, name)
    os.makedirs(path, exist_ok=False)

    if use_dev_shm:
        with open(os.path.join(path, "owner.pid"), "w") as f:
            f.write(str(os.getpid()))

    return path