1"""File invoked through subprocess to actually carry out measurements.
2
3`worker/main.py` is deliberately isolated from the rest of the benchmark
4infrastructure. Other parts of the benchmark rely on this file, but
5`worker/` has only one Python file and does not import ANYTHING from the rest
6of the benchmark suite. The reason that this is important is that we can't
7rely on paths to access the other files (namely `core.api`) since a source
8command might change the CWD. It also helps keep startup time down by limiting
9spurious definition work.
10
11The life of a worker is very simple:
12    It receives a file containing a `WorkerTimerArgs` telling it what to run,
13    and writes a `WorkerOutput` result back to the same file.
14
15Because this file only expects to run in a child context, error handling means
16plumbing failures up to the caller, not raising in this process.
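
The parent side of that contract looks roughly like the sketch below
(illustrative only: the `communication_file` path, the `subprocess` usage,
and the benchmark values are assumed, not the actual runner code):

    # 1) Pickle a WorkerTimerArgs describing the benchmark.
    with open(communication_file, "wb") as f:
        pickle.dump(WorkerTimerArgs(stmt="y = x + 1", setup="x = 1"), f)

    # 2) Invoke this file as a standalone script.
    subprocess.run(
        ["python", WORKER_PATH, "--communication-file", communication_file],
        check=True,
    )

    # 3) Read a WorkerOutput (or WorkerFailure) back from the same file,
    #    using WorkerUnpickler (defined below) to resolve the symbols.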
17"""
18import argparse
19import dataclasses
20import io
21import os
22import pickle
23import sys
24import timeit
25import traceback
26from typing import Any, Tuple, TYPE_CHECKING, Union
27
28
29if TYPE_CHECKING:
30    # Benchmark utils are only partially strict compliant, so MyPy won't follow
31    # imports using the public namespace. (Due to an exclusion rule in
32    # mypy-strict.ini)
33    from torch.utils.benchmark.utils.timer import Language, Timer
34    from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import (
35        CallgrindStats,
36    )
37
38else:
39    from torch.utils.benchmark import CallgrindStats, Language, Timer
40
41
42WORKER_PATH = os.path.abspath(__file__)
43
44
45# =============================================================================
46# == Interface ================================================================
47# =============================================================================
48
49# While the point of this is mainly to collect instruction counts, we're going
50# to have to compile C++ timers anyway (as they're used as a check before
51# calling Valgrind), so we may as well grab wall times for reference. They
52# are comparatively inexpensive.
53MIN_RUN_TIME = 5
54
55# Repeats are inexpensive as long as they are all run in the same process. This
56# also lets us filter outliers (e.g. malloc arena reorganization), so we don't
57# need a high CALLGRIND_NUMBER to get good data.
58CALLGRIND_NUMBER = 100
59CALLGRIND_REPEATS = 5
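# Note: with these settings each worker ends up replaying `stmt` roughly
# CALLGRIND_NUMBER * CALLGRIND_REPEATS = 500 times under Valgrind, all within
# a single instrumented process.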


@dataclasses.dataclass(frozen=True)
class WorkerTimerArgs:
    """Container for Timer constructor arguments.

    This dataclass serves two roles. First, it is a simple interface for
    defining benchmarks. (See core.api.GroupedStmts and core.api.GroupedModules
    for the advanced interfaces.) Second, it provides serialization for
    controlling workers. `Timer` is not pickleable, so instead the main process
    will pass `WorkerTimerArgs` instances to workers for processing.
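
    For example (illustrative values; `global_setup` is generally only
    relevant for C++ snippets):

        WorkerTimerArgs(
            stmt="y = x + 1",
            setup="x = 1",
            num_threads=1,
            language=Language.PYTHON,
        )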
71    """
72
73    stmt: str
74    setup: str = "pass"
75    global_setup: str = ""
76    num_threads: int = 1
77    language: Language = Language.PYTHON
78
79
80@dataclasses.dataclass(frozen=True)
81class WorkerOutput:
82    # Only return values to reduce communication between main process and workers.
83    wall_times: Tuple[float, ...]
84    instructions: Tuple[int, ...]
85
86
87@dataclasses.dataclass(frozen=True)
88class WorkerFailure:
89    # If a worker fails, we attach the string contents of the Exception
90    # rather than the Exception object itself. This is done for two reasons:
91    #   1) Depending on the type thrown, `e` may or may not be pickleable
92    #   2) If we re-throw in the main process, we lose the true stack trace.
93    failure_trace: str
94
95
96class WorkerUnpickler(pickle.Unpickler):
97    def find_class(self, module: str, name: str) -> Any:
98        """Resolve import for pickle.
99
100        When the main runner uses a symbol `foo` from this file, it sees it as
101        `worker.main.foo`. However the worker (called as a standalone file)
102        sees the same symbol as `__main__.foo`. We have to help pickle
103        understand that they refer to the same symbols.
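
        For example, a pickle written by the runner may reference
        `worker.main.WorkerTimerArgs`; when this standalone process unpickles
        it, `find_class` maps that name back to the `WorkerTimerArgs` defined
        in this file (and vice versa for results flowing back to the runner).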
104        """
105        symbol_map = {
106            # Only blessed interface Enums and dataclasses need to be mapped.
107            "WorkerTimerArgs": WorkerTimerArgs,
108            "WorkerOutput": WorkerOutput,
109            "WorkerFailure": WorkerFailure,
110        }
111
112        if name in symbol_map:
113            return symbol_map[name]
114
115        return super().find_class(module, name)
116
117    def load_input(self) -> WorkerTimerArgs:
118        result = self.load()
119        assert isinstance(result, WorkerTimerArgs)
120        return result
121
122    def load_output(self) -> Union[WorkerTimerArgs, WorkerOutput, WorkerFailure]:
123        """Convenience method for type safe loading."""
124        result = self.load()
125        assert isinstance(result, (WorkerTimerArgs, WorkerOutput, WorkerFailure))
126        return result
127
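# A parent-side read of the communication file might look like the following.
# (Illustrative sketch only; the real runner lives outside `worker/` and the
# `communication_file` name is assumed.)
#
#   with open(communication_file, "rb") as f:
#       result = WorkerUnpickler(f).load_output()
#   if isinstance(result, WorkerFailure):
#       raise RuntimeError(f"Worker failed:\n{result.failure_trace}")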

# =============================================================================
# == Execution ================================================================
# =============================================================================


def _run(timer_args: WorkerTimerArgs) -> WorkerOutput:
    timer = Timer(
        stmt=timer_args.stmt,
        setup=timer_args.setup or "pass",
        global_setup=timer_args.global_setup,
        # Prevent NotImplementedError on GPU builds and C++ snippets.
        timer=timeit.default_timer,
        num_threads=timer_args.num_threads,
        language=timer_args.language,
    )

    m = timer.blocked_autorange(min_run_time=MIN_RUN_TIME)

    stats: Tuple[CallgrindStats, ...] = timer.collect_callgrind(
        number=CALLGRIND_NUMBER,
        collect_baseline=False,
        repeats=CALLGRIND_REPEATS,
        retain_out_file=False,
    )

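    # NOTE: `counts(denoise=True)` filters out instructions that
    # torch.utils.benchmark attributes to known-noisy functions, which makes
    # the per-repeat totals more directly comparable.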
    return WorkerOutput(
        wall_times=tuple(m.times),
        instructions=tuple(s.counts(denoise=True) for s in stats),
    )


def main(communication_file: str) -> None:
    result: Union[WorkerOutput, WorkerFailure]
    try:
        with open(communication_file, "rb") as f:
            timer_args: WorkerTimerArgs = WorkerUnpickler(f).load_input()
            assert isinstance(timer_args, WorkerTimerArgs)
        result = _run(timer_args)

    except KeyboardInterrupt:
        # Runner process sent SIGINT.
        sys.exit()

    except BaseException:
        trace_f = io.StringIO()
        traceback.print_exc(file=trace_f)
        result = WorkerFailure(failure_trace=trace_f.getvalue())

    if not os.path.exists(os.path.split(communication_file)[0]):
        # This worker is an orphan, and the parent has already cleaned up the
        # working directory. In that case we can simply exit.
        print(f"Orphaned worker {os.getpid()} exiting.")
        return

    with open(communication_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--communication-file", "--communication_file", type=str)
    communication_file = parser.parse_args().communication_file
    main(communication_file)