# xref: /aosp_15_r20/external/pytorch/benchmarks/instruction_counts/execution/work.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""Handle the details of subprocess calls and retries for a given benchmark run."""
# mypy: ignore-errors
import dataclasses
import json
import os
import pickle
import signal
import subprocess
import time
import uuid
from typing import List, Optional, TYPE_CHECKING, Union

from core.api import AutoLabels
from core.types import Label
from core.utils import get_temp_dir
from worker.main import (
    WORKER_PATH,
    WorkerFailure,
    WorkerOutput,
    WorkerTimerArgs,
    WorkerUnpickler,
)


if TYPE_CHECKING:
    PopenType = subprocess.Popen[bytes]
else:
    PopenType = subprocess.Popen


# Mitigate https://github.com/pytorch/pytorch/issues/37377
_ENV = "MKL_THREADING_LAYER=GNU"
_PYTHON = "python"
PYTHON_CMD = f"{_ENV} {_PYTHON}"

# We must specify `bash` so that `source activate ...` always works
SHELL = "/bin/bash"
39
@dataclasses.dataclass(frozen=True)
class WorkOrder:
    """Spec to schedule work with the benchmark runner."""

    label: Label
    autolabels: AutoLabels
    timer_args: WorkerTimerArgs
    source_cmd: Optional[str] = None
    timeout: Optional[float] = None
    retries: int = 0

    def __hash__(self) -> int:
        # Hash on identity: each WorkOrder instance is tracked as a
        # distinct job by the runner.
        return id(self)

    def __str__(self) -> str:
        # Compact JSON summary used for logging / progress display.
        summary = {
            "label": self.label,
            "autolabels": self.autolabels.as_dict,
            "num_threads": self.timer_args.num_threads,
        }
        return json.dumps(summary)
62
63
class _BenchmarkProcess:
    """Wraps subprocess.Popen for a given WorkOrder.

    The worker's ``TimerArgs`` are pickled into a temp "communication file";
    the worker process is launched via a shell command and is expected to
    write its result back into that same file, which this class reads once
    the subprocess has exited.
    """

    _work_order: WorkOrder
    _cpu_list: Optional[str]
    _proc: PopenType

    # Internal bookkeeping
    _communication_file: str
    _start_time: float
    _end_time: Optional[float] = None
    _retcode: Optional[int]
    _result: Optional[Union[WorkerOutput, WorkerFailure]] = None

    def __init__(self, work_order: WorkOrder, cpu_list: Optional[str]) -> None:
        self._work_order = work_order
        self._cpu_list = cpu_list
        self._start_time = time.time()
        # Unique per-process file used for two-way communication with the
        # worker: we write TimerArgs, the worker overwrites with its output.
        self._communication_file = os.path.join(get_temp_dir(), f"{uuid.uuid4()}.pkl")
        with open(self._communication_file, "wb") as f:
            pickle.dump(self._work_order.timer_args, f)

        # `shell=True` because `cmd` may begin with a `source_cmd &&` prefix
        # (see `cmd` below). stderr is merged into stdout so that a failure
        # trace can be recovered from a single stream in `_maybe_collect`.
        self._proc = subprocess.Popen(
            self.cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            shell=True,
            executable=SHELL,
        )

    def clone(self) -> "_BenchmarkProcess":
        """Launch a fresh process for the same WorkOrder (used for retries)."""
        return _BenchmarkProcess(self._work_order, self._cpu_list)

    @property
    def cmd(self) -> str:
        """Build the shell command that runs the worker for this WorkOrder."""
        cmd: List[str] = []
        if self._work_order.source_cmd is not None:
            # e.g. `source activate my_env` before invoking Python.
            cmd.extend([self._work_order.source_cmd, "&&"])

        cmd.append(_ENV)

        if self._cpu_list is not None:
            # Pin both OpenMP (GOMP) threads and the process itself to the
            # requested CPUs.
            cmd.extend(
                [
                    f"GOMP_CPU_AFFINITY={self._cpu_list}",
                    "taskset",
                    "--cpu-list",
                    self._cpu_list,
                ]
            )

        cmd.extend(
            [
                _PYTHON,
                WORKER_PATH,
                "--communication-file",
                self._communication_file,
            ]
        )
        return " ".join(cmd)

    @property
    def duration(self) -> float:
        # Wall-clock time so far; frozen once results have been collected
        # (at which point `_end_time` is set).
        return (self._end_time or time.time()) - self._start_time

    @property
    def result(self) -> Union[WorkerOutput, WorkerFailure]:
        # NOTE: only valid after the subprocess has exited; the assert fires
        # if called while the worker is still running.
        self._maybe_collect()
        assert self._result is not None
        return self._result

    def poll(self) -> Optional[int]:
        """Return the subprocess return code, or None if still running."""
        self._maybe_collect()
        return self._retcode

    def interrupt(self) -> None:
        """Soft interrupt. Allows subprocess to cleanup."""
        self._proc.send_signal(signal.SIGINT)

    def terminate(self) -> None:
        """Hard interrupt. Immediately SIGTERM subprocess."""
        self._proc.terminate()

    def _maybe_collect(self) -> None:
        """Collect the worker's result once the subprocess has exited.

        Idempotent: subsequent calls return immediately after the first
        successful collection.
        """
        if self._result is not None:
            # We've already collected the results.
            return

        self._retcode = self._proc.poll()
        if self._retcode is None:
            # `_proc` is still running
            return

        with open(self._communication_file, "rb") as f:
            result = WorkerUnpickler(f).load_output()

        if isinstance(result, WorkerOutput) and self._retcode:
            # Worker managed to complete the designated task, but worker
            # process did not finish cleanly.
            result = WorkerFailure("Worker failed silently.")

        if isinstance(result, WorkerTimerArgs):
            # Worker failed, but did not write a result so we're left with the
            # original TimerArgs. Grabbing all of stdout and stderr isn't
            # ideal, but we don't have a better way to determine what to keep.
            proc_stdout = self._proc.stdout
            assert proc_stdout is not None
            result = WorkerFailure(failure_trace=proc_stdout.read().decode("utf-8"))

        self._result = result
        self._end_time = time.time()

        # Release communication file.
        os.remove(self._communication_file)
178
179
class InProgress:
    """Used by the benchmark runner to track outstanding jobs.
    This class handles bookkeeping and timeout + retry logic.
    """

    _proc: _BenchmarkProcess
    _timeouts: int = 0

    def __init__(self, work_order: WorkOrder, cpu_list: Optional[str]):
        self._work_order = work_order
        self._proc = _BenchmarkProcess(work_order, cpu_list)

    @property
    def work_order(self) -> WorkOrder:
        # Delegate to the underlying process so the value stays correct
        # across retries (where `_proc` is replaced by a clone).
        return self._proc._work_order

    @property
    def cpu_list(self) -> Optional[str]:
        return self._proc._cpu_list

    @property
    def proc(self) -> _BenchmarkProcess:
        # NB: For cleanup only.
        return self._proc

    @property
    def duration(self) -> float:
        return self._proc.duration

    def check_finished(self) -> bool:
        """Poll the worker, restarting it on timeout while retries remain.

        Returns True once the subprocess has exited. Raises
        `subprocess.TimeoutExpired` when the final allowed attempt also
        exceeds the WorkOrder's timeout.
        """
        if self._proc.poll() is not None:
            return True

        budget = self.work_order.timeout
        if budget is None:
            return False
        if self._proc.duration < budget:
            return False

        # The current attempt has exceeded its time budget.
        self._timeouts += 1
        allowed_attempts = (self._work_order.retries or 0) + 1
        if self._timeouts >= allowed_attempts:
            # Out of retries: surface the timeout to the runner.
            raise subprocess.TimeoutExpired(cmd=self._proc.cmd, timeout=budget)

        print(
            f"\nTimeout: {self._work_order.label}, {self._work_order.autolabels} "
            f"(Attempt {self._timeouts} / {allowed_attempts})"
        )
        self._proc.interrupt()
        self._proc = self._proc.clone()
        return False

    @property
    def result(self) -> Union[WorkerOutput, WorkerFailure]:
        return self._proc.result

    def __hash__(self) -> int:
        return id(self)
236