1"""Handle the details of subprocess calls and retries for a given benchmark run.""" 2# mypy: ignore-errors 3import dataclasses 4import json 5import os 6import pickle 7import signal 8import subprocess 9import time 10import uuid 11from typing import List, Optional, TYPE_CHECKING, Union 12 13from core.api import AutoLabels 14from core.types import Label 15from core.utils import get_temp_dir 16from worker.main import ( 17 WORKER_PATH, 18 WorkerFailure, 19 WorkerOutput, 20 WorkerTimerArgs, 21 WorkerUnpickler, 22) 23 24 25if TYPE_CHECKING: 26 PopenType = subprocess.Popen[bytes] 27else: 28 PopenType = subprocess.Popen 29 30 31# Mitigate https://github.com/pytorch/pytorch/issues/37377 32_ENV = "MKL_THREADING_LAYER=GNU" 33_PYTHON = "python" 34PYTHON_CMD = f"{_ENV} {_PYTHON}" 35 36# We must specify `bash` so that `source activate ...` always works 37SHELL = "/bin/bash" 38 39 40@dataclasses.dataclass(frozen=True) 41class WorkOrder: 42 """Spec to schedule work with the benchmark runner.""" 43 44 label: Label 45 autolabels: AutoLabels 46 timer_args: WorkerTimerArgs 47 source_cmd: Optional[str] = None 48 timeout: Optional[float] = None 49 retries: int = 0 50 51 def __hash__(self) -> int: 52 return id(self) 53 54 def __str__(self) -> str: 55 return json.dumps( 56 { 57 "label": self.label, 58 "autolabels": self.autolabels.as_dict, 59 "num_threads": self.timer_args.num_threads, 60 } 61 ) 62 63 64class _BenchmarkProcess: 65 """Wraps subprocess.Popen for a given WorkOrder.""" 66 67 _work_order: WorkOrder 68 _cpu_list: Optional[str] 69 _proc: PopenType 70 71 # Internal bookkeeping 72 _communication_file: str 73 _start_time: float 74 _end_time: Optional[float] = None 75 _retcode: Optional[int] 76 _result: Optional[Union[WorkerOutput, WorkerFailure]] = None 77 78 def __init__(self, work_order: WorkOrder, cpu_list: Optional[str]) -> None: 79 self._work_order = work_order 80 self._cpu_list = cpu_list 81 self._start_time = time.time() 82 self._communication_file = os.path.join(get_temp_dir(), f"{uuid.uuid4()}.pkl") 83 with open(self._communication_file, "wb") as f: 84 pickle.dump(self._work_order.timer_args, f) 85 86 self._proc = subprocess.Popen( 87 self.cmd, 88 stdout=subprocess.PIPE, 89 stderr=subprocess.STDOUT, 90 shell=True, 91 executable=SHELL, 92 ) 93 94 def clone(self) -> "_BenchmarkProcess": 95 return _BenchmarkProcess(self._work_order, self._cpu_list) 96 97 @property 98 def cmd(self) -> str: 99 cmd: List[str] = [] 100 if self._work_order.source_cmd is not None: 101 cmd.extend([self._work_order.source_cmd, "&&"]) 102 103 cmd.append(_ENV) 104 105 if self._cpu_list is not None: 106 cmd.extend( 107 [ 108 f"GOMP_CPU_AFFINITY={self._cpu_list}", 109 "taskset", 110 "--cpu-list", 111 self._cpu_list, 112 ] 113 ) 114 115 cmd.extend( 116 [ 117 _PYTHON, 118 WORKER_PATH, 119 "--communication-file", 120 self._communication_file, 121 ] 122 ) 123 return " ".join(cmd) 124 125 @property 126 def duration(self) -> float: 127 return (self._end_time or time.time()) - self._start_time 128 129 @property 130 def result(self) -> Union[WorkerOutput, WorkerFailure]: 131 self._maybe_collect() 132 assert self._result is not None 133 return self._result 134 135 def poll(self) -> Optional[int]: 136 self._maybe_collect() 137 return self._retcode 138 139 def interrupt(self) -> None: 140 """Soft interrupt. Allows subprocess to cleanup.""" 141 self._proc.send_signal(signal.SIGINT) 142 143 def terminate(self) -> None: 144 """Hard interrupt. 
Immediately SIGTERM subprocess.""" 145 self._proc.terminate() 146 147 def _maybe_collect(self) -> None: 148 if self._result is not None: 149 # We've already collected the results. 150 return 151 152 self._retcode = self._proc.poll() 153 if self._retcode is None: 154 # `_proc` is still running 155 return 156 157 with open(self._communication_file, "rb") as f: 158 result = WorkerUnpickler(f).load_output() 159 160 if isinstance(result, WorkerOutput) and self._retcode: 161 # Worker managed to complete the designated task, but worker 162 # process did not finish cleanly. 163 result = WorkerFailure("Worker failed silently.") 164 165 if isinstance(result, WorkerTimerArgs): 166 # Worker failed, but did not write a result so we're left with the 167 # original TimerArgs. Grabbing all of stdout and stderr isn't 168 # ideal, but we don't have a better way to determine what to keep. 169 proc_stdout = self._proc.stdout 170 assert proc_stdout is not None 171 result = WorkerFailure(failure_trace=proc_stdout.read().decode("utf-8")) 172 173 self._result = result 174 self._end_time = time.time() 175 176 # Release communication file. 177 os.remove(self._communication_file) 178 179 180class InProgress: 181 """Used by the benchmark runner to track outstanding jobs. 182 This class handles bookkeeping and timeout + retry logic. 183 """ 184 185 _proc: _BenchmarkProcess 186 _timeouts: int = 0 187 188 def __init__(self, work_order: WorkOrder, cpu_list: Optional[str]): 189 self._work_order = work_order 190 self._proc = _BenchmarkProcess(work_order, cpu_list) 191 192 @property 193 def work_order(self) -> WorkOrder: 194 return self._proc._work_order 195 196 @property 197 def cpu_list(self) -> Optional[str]: 198 return self._proc._cpu_list 199 200 @property 201 def proc(self) -> _BenchmarkProcess: 202 # NB: For cleanup only. 203 return self._proc 204 205 @property 206 def duration(self) -> float: 207 return self._proc.duration 208 209 def check_finished(self) -> bool: 210 if self._proc.poll() is not None: 211 return True 212 213 timeout = self.work_order.timeout 214 if timeout is None or self._proc.duration < timeout: 215 return False 216 217 self._timeouts += 1 218 max_attempts = (self._work_order.retries or 0) + 1 219 if self._timeouts < max_attempts: 220 print( 221 f"\nTimeout: {self._work_order.label}, {self._work_order.autolabels} " 222 f"(Attempt {self._timeouts} / {max_attempts})" 223 ) 224 self._proc.interrupt() 225 self._proc = self._proc.clone() 226 return False 227 228 raise subprocess.TimeoutExpired(cmd=self._proc.cmd, timeout=timeout) 229 230 @property 231 def result(self) -> Union[WorkerOutput, WorkerFailure]: 232 return self._proc.result 233 234 def __hash__(self) -> int: 235 return id(self) 236
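

# Illustrative usage sketch (not part of the original module): a runner loop
# might drive these classes roughly as follows, assuming `work_order` is an
# already constructed `WorkOrder` (the `WorkerTimerArgs` and `AutoLabels`
# constructors live in `worker.main` / `core.api` and are not shown here).
#
#     job = InProgress(work_order, cpu_list=None)
#     while not job.check_finished():
#         # `check_finished` handles retries and raises
#         # `subprocess.TimeoutExpired` once attempts are exhausted.
#         time.sleep(1)
#     result = job.result  # WorkerOutput on success, WorkerFailure otherwise.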