xref: /aosp_15_r20/external/pytorch/torch/multiprocessing/spawn.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import logging
import multiprocessing
import multiprocessing.connection
import os
import pickle
import signal
import sys
import tempfile
import time
import warnings
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional

from . import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]


ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"

log = logging.getLogger(__name__)

__all__ = [
    "ProcessContext",
    "ProcessException",
    "ProcessExitedException",
    "ProcessRaisedException",
    "spawn",
    "SpawnContext",
    "start_processes",
]


class ProcessException(Exception):
    __slots__ = ["error_index", "error_pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        self.msg = msg
        self.error_index = error_index
        self.pid = pid

    def __reduce__(self):
        return type(self), (self.msg, self.error_index, self.pid)


class ProcessRaisedException(ProcessException):
    """Exception raised when a process failed due to an exception raised by the code."""

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
    ):
        super().__init__(msg, error_index, error_pid)


class ProcessExitedException(ProcessException):
    """Exception raised when a process failed due to a signal or exited with a specific code."""

    __slots__ = ["exit_code"]

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
        exit_code: int,
        signal_name: Optional[str] = None,
    ):
        super().__init__(msg, error_index, error_pid)
        self.exit_code = exit_code
        self.signal_name = signal_name

    def __reduce__(self):
        return (
            type(self),
            (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
        )


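# Note: callers that want to tell the two failure modes apart can catch the
# exceptions raised by ``ProcessContext.join`` separately. A minimal sketch
# (assumes a top-level worker function ``fn``; not part of this module):
#
#     try:
#         spawn(fn, args=(), nprocs=2)
#     except ProcessRaisedException as e:
#         ...  # a child raised a Python exception; str(e) contains its traceback
#     except ProcessExitedException as e:
#         ...  # a child was killed or exited abnormally; see e.exit_code / e.signal_name
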
def _wrap(fn, i, args, error_file):
    # prctl(2) is a Linux-specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        pass  # SIGINT; killed by parent, do nothing
    except Exception:
        # Propagate the exception to the parent process, keeping the original traceback
        import traceback

        with open(error_file, "wb") as fh:
            pickle.dump(traceback.format_exc(), fh)
        sys.exit(1)


class ProcessContext:
    def __init__(self, processes, error_files):
        self.error_files = error_files
        self.processes = processes
        self.sentinels = {
            process.sentinel: index for index, process in enumerate(processes)
        }

    def pids(self):
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""Join one or more processes within spawn context.

        Attempt to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.
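
        A minimal polling sketch (assumes ``fn`` is a top-level worker function
        defined by the caller)::

            >>> ctx = start_processes(fn, nprocs=2, join=False)
            >>> while not ctx.join(timeout=1.0):
            ...     pass  # processes still running; do other work between polls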
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive:
        # try SIGTERM first, then SIGKILL if a process won't go down.
        # The reason is that Python signal handling is limited to the main
        # thread, and if that thread is stuck in C/C++ code it won't get a
        # chance to handle the signal. We have seen processes get stuck and
        # never handle SIGTERM for exactly this reason.
        timeout: int = 30
        for process in self.processes:
            if process.is_alive():
                log.warning("Terminating process %s via signal SIGTERM", process.pid)
                process.terminate()
        end = time.monotonic() + timeout
        for process in self.processes:
            time_to_wait = max(0, end - time.monotonic())
            process.join(time_to_wait)
        for process in self.processes:
            if process.is_alive():
                log.warning(
                    "Unable to shut down process %s via SIGTERM, forcefully exiting via SIGKILL",
                    process.pid,
                )
                process.kill()
            process.join()

        # The file will only be created if the process crashed.
        failed_process = self.processes[error_index]
        if not os.access(self.error_files[error_index], os.R_OK):
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    name = f"<Unknown signal {-exitcode}>"
                raise ProcessExitedException(
                    "process %d terminated with signal %s" % (error_index, name),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                    signal_name=name,
                )
            else:
                raise ProcessExitedException(
                    "process %d terminated with exit code %d" % (error_index, exitcode),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                )

        with open(self.error_files[error_index], "rb") as fh:
            original_trace = pickle.load(fh)
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise ProcessRaisedException(msg, error_index, failed_process.pid)


class SpawnContext(ProcessContext):
    def __init__(self, processes, error_files):
        warnings.warn("SpawnContext has been renamed to ProcessContext since the 1.4 release.")
        super().__init__(processes, error_files)


# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a
# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the
# CUDA-compatible start_method. However, in environments like IPython notebooks, 'fork'
# works better than 'spawn'. Every helper function we created for mp.spawn is indeed
# general enough, and backends like XLA can reuse them in Colab notebooks as well.
# For now we only add this API; we can consider adding it to the documentation as
# needed in the future.
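#
# A minimal sketch of notebook-style usage (hypothetical worker ``run``, not part
# of this module); with 'fork' the worker does not have to be importable/picklable:
#
#     def run(rank, msg):
#         print(f"worker {rank}: {msg}")
#
#     start_processes(run, args=("hello",), nprocs=2, start_method="fork")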
def start_processes(
    fn,
    args=(),
    nprocs=1,
    join=True,
    daemon=False,
    start_method="spawn",
):
    # To speed up process startup in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
    # this function starts processes in parallel when start_method is 'forkserver'.
    # Opt in to this perf optimization by setting the env var TORCH_MP_PARALLEL_START
    # to 1 (see the commented sketch after this function).
    # TODO: investigate why 'spawn' does not work with the thread pool and raises SIGINT
    if (
        start_method == "forkserver"
        and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
    ):
        log.info("Starting processes in parallel.")
        start_parallel = True
    else:
        # Parallel start stays disabled unless TORCH_MP_PARALLEL_START is set to 1.
        start_parallel = False

    mp = multiprocessing.get_context(start_method)
    error_files = [None] * nprocs
    processes = [None] * nprocs

    def start_process(i):
        # Each process is assigned a file to write tracebacks to.  We
        # use the file being non-empty to indicate an exception
        # occurred (vs an expected shutdown).  Note: this previously
        # used a multiprocessing.Queue but that can be prone to
        # deadlocks, so we went with a simpler solution for a one-shot
        # message between processes.
        tf = tempfile.NamedTemporaryFile(
            prefix="pytorch-errorfile-", suffix=".pickle", delete=False
        )
        tf.close()
        os.unlink(tf.name)
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, tf.name),
            daemon=daemon,
        )
        process.start()
        return i, process, tf.name

    if not start_parallel:
        for i in range(nprocs):
            idx, process, tf_name = start_process(i)
            error_files[idx] = tf_name
            processes[idx] = process
    else:
        with ThreadPoolExecutor(max_workers=nprocs) as executor:
            futures = [executor.submit(start_process, i) for i in range(nprocs)]
            for fut in as_completed(futures):
                idx, process, tf_name = fut.result()
                # idx and process rank need to be the same.
                error_files[idx] = tf_name
                processes[idx] = process
    context = ProcessContext(processes, error_files)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass


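# A minimal sketch of the opt-in parallel start mentioned above (hypothetical
# top-level worker ``run``; whether it helps depends on the workload, see the
# issue linked in start_processes):
#
#     os.environ["TORCH_MP_PARALLEL_START"] = "1"
#     start_processes(run, nprocs=8, start_method="forkserver")

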
def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. If an exception was caught in a child
    process, it is forwarded and its traceback is included in the
    exception raised in the parent process.

    Args:
        fn (function): Function that is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed-through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
                       daemonic processes will be created.
        start_method (str): (deprecated) this method will always use ``spawn``
                               as the start method. To use a different start method
                               use ``start_processes()``.

    Returns:
        None if ``join`` is ``True``,
        :class:`~ProcessContext` if ``join`` is ``False``

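    A minimal usage sketch, where ``train(i, num_steps)`` is a hypothetical worker
    function defined at the top level of a module::

        >>> spawn(train, args=(1000,), nprocs=4)
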
    """
    if start_method != "spawn":
        msg = (
            f"This method only supports start_method=spawn (got: {start_method}).\n"
            "To use a different start_method use:\n\t\t"
            " torch.multiprocessing.start_processes(...)"
        )
        warnings.warn(msg, FutureWarning, stacklevel=2)
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
