"""File invoked through subprocess to actually carry out measurements.

`worker/main.py` is deliberately isolated from the rest of the benchmark
infrastructure. Other parts of the benchmark rely on this file, but
`worker/` has only one Python file and does not import ANYTHING from the rest
of the benchmark suite. The reason that this is important is that we can't
rely on paths to access the other files (namely `core.api`) since a source
command might change the CWD. It also helps keep startup time down by limiting
spurious definition work.

The life of a worker is very simple:
    It receives a file containing a `WorkerTimerArgs` telling it what to run,
    and writes a `WorkerOutput` result back to the same file.

Because this file only expects to run in a child context, error handling means
plumbing failures up to the caller, not raising in this process.
"""
import argparse
import dataclasses
import io
import os
import pickle
import sys
import timeit
import traceback
from typing import Any, Tuple, TYPE_CHECKING, Union


if TYPE_CHECKING:
    # Benchmark utils are only partially strict compliant, so MyPy won't follow
    # imports using the public namespace. (Due to an exclusion rule in
    # mypy-strict.ini)
    from torch.utils.benchmark.utils.timer import Language, Timer
    from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import (
        CallgrindStats,
    )

else:
    from torch.utils.benchmark import CallgrindStats, Language, Timer


# Absolute path of this file; exposed so the runner can invoke the worker.
WORKER_PATH: str = os.path.abspath(__file__)


# =============================================================================
# == Interface ================================================================
# =============================================================================

# While the point of this is mainly to collect instruction counts, we're going
# to have to compile C++ timers anyway (as they're used as a check before
# calling Valgrind), so we may as well grab wall times for reference. They
# are comparatively inexpensive.
MIN_RUN_TIME: int = 5

# Repeats are inexpensive as long as they are all run in the same process. This
# also lets us filter outliers (e.g. malloc arena reorganization), so we don't
# need a high CALLGRIND_NUMBER to get good data.
CALLGRIND_NUMBER: int = 100
CALLGRIND_REPEATS: int = 5


@dataclasses.dataclass(frozen=True)
class WorkerTimerArgs:
    """Container for Timer constructor arguments.

    This dataclass serves two roles. First, it is a simple interface for
    defining benchmarks. (See core.api.GroupedStmts and core.api.GroupedModules
    for the advanced interfaces.) Second, it provides serialization for
    controlling workers. `Timer` is not pickleable, so instead the main process
    will pass `WorkerTimerArgs` instances to workers for processing.

    NOTE: field names and order are part of the pickle protocol between the
    runner and the worker; every field is forwarded to the `Timer` constructor
    (see `_run`).
    """

    stmt: str                              # Snippet to benchmark (required).
    setup: str = "pass"                    # Setup run before timing `stmt`.
    global_setup: str = ""                 # Forwarded to `Timer(global_setup=...)`.
    num_threads: int = 1                   # Forwarded to `Timer(num_threads=...)`.
    language: Language = Language.PYTHON   # Whether `stmt` is Python or C++.


@dataclasses.dataclass(frozen=True)
class WorkerOutput:
    # Only return values to reduce communication between main process and workers.
    wall_times: Tuple[float, ...]    # From `Timer.blocked_autorange` (see `_run`).
    instructions: Tuple[int, ...]    # Denoised count per Callgrind repeat.


@dataclasses.dataclass(frozen=True)
class WorkerFailure:
    """Stringified stack trace of a worker-side exception.

    If a worker fails, we attach the string contents of the Exception
    rather than the Exception object itself. This is done for two reasons:
        1) Depending on the type thrown, `e` may or may not be pickleable
        2) If we re-throw in the main process, we lose the true stack trace.
    """
    failure_trace: str


class WorkerUnpickler(pickle.Unpickler):
    def find_class(self, module: str, name: str) -> Any:
        """Resolve import for pickle.

        When the main runner uses a symbol `foo` from this file, it sees it as
        `worker.main.foo`. However the worker (called as a standalone file)
        sees the same symbol as `__main__.foo`. We have to help pickle
        understand that they refer to the same symbols.
        """
        symbol_map = {
            # Only blessed interface Enums and dataclasses need to be mapped.
            "WorkerTimerArgs": WorkerTimerArgs,
            "WorkerOutput": WorkerOutput,
            "WorkerFailure": WorkerFailure,
        }

        if name in symbol_map:
            return symbol_map[name]

        return super().find_class(module, name)

    def load_input(self) -> WorkerTimerArgs:
        """Type-safe load for the worker side of the protocol."""
        result = self.load()
        assert isinstance(result, WorkerTimerArgs)
        return result

    def load_output(self) -> Union[WorkerTimerArgs, WorkerOutput, WorkerFailure]:
        """Convenience method for type safe loading.

        `WorkerTimerArgs` is an accepted result because the file still holds
        the input if the worker died before writing anything back.
        """
        result = self.load()
        assert isinstance(result, (WorkerTimerArgs, WorkerOutput, WorkerFailure))
        return result


# =============================================================================
# == Execution ================================================================
# =============================================================================


def _run(timer_args: WorkerTimerArgs) -> WorkerOutput:
    """Measure wall times and Callgrind instruction counts for one benchmark.

    Args:
        timer_args: Timer constructor arguments received from the runner.

    Returns:
        A `WorkerOutput` with per-measurement wall times and one denoised
        instruction count per Callgrind repeat.
    """
    timer = Timer(
        stmt=timer_args.stmt,
        setup=timer_args.setup or "pass",
        global_setup=timer_args.global_setup,
        # Prevent NotImplementedError on GPU builds and C++ snippets.
        timer=timeit.default_timer,
        num_threads=timer_args.num_threads,
        language=timer_args.language,
    )

    m = timer.blocked_autorange(min_run_time=MIN_RUN_TIME)

    stats: Tuple[CallgrindStats, ...] = timer.collect_callgrind(
        number=CALLGRIND_NUMBER,
        collect_baseline=False,
        repeats=CALLGRIND_REPEATS,
        retain_out_file=False,
    )

    return WorkerOutput(
        wall_times=tuple(m.times),
        instructions=tuple(s.counts(denoise=True) for s in stats),
    )


def main(communication_file: str) -> None:
    """Worker entry point: read args from, and write results to, one file.

    Any failure (other than SIGINT from the runner) is converted into a
    `WorkerFailure` and written back instead of raising in this process.
    """
    result: Union[WorkerOutput, WorkerFailure]
    try:
        with open(communication_file, "rb") as f:
            # `load_input` already asserts the payload is a `WorkerTimerArgs`.
            timer_args: WorkerTimerArgs = WorkerUnpickler(f).load_input()

        # Run outside the `with` block so the communication file is closed
        # during the (potentially long) measurement.
        result = _run(timer_args)

    except KeyboardInterrupt:
        # Runner process sent SIGINT.
        sys.exit()

    except BaseException:
        # Deliberately broad: the whole point of the worker is to plumb any
        # failure back to the caller rather than raise here.
        result = WorkerFailure(failure_trace=traceback.format_exc())

    if not os.path.exists(os.path.split(communication_file)[0]):
        # This worker is an orphan, and the parent has already cleaned up the
        # working directory. In that case we can simply exit.
        print(f"Orphaned worker {os.getpid()} exiting.")
        return

    with open(communication_file, "wb") as f:
        pickle.dump(result, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # `required=True` gives a clear argparse error instead of a later
    # `open(None)` TypeError when the flag is omitted.
    parser.add_argument(
        "--communication-file", "--communication_file", type=str, required=True
    )
    communication_file = parser.parse_args().communication_file
    main(communication_file)