# mypy: allow-untyped-defs
import logging
import multiprocessing
import multiprocessing.connection
import os
import pickle
import signal
import sys
import tempfile
import time
import warnings
from concurrent.futures import as_completed, ThreadPoolExecutor
from typing import Optional

from . import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]


ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"

log = logging.getLogger(__name__)

__all__ = [
    "ProcessContext",
    "ProcessException",
    "ProcessExitedException",
    "ProcessRaisedException",
    "spawn",
    "SpawnContext",
    "start_processes",
]


class ProcessException(Exception):
    __slots__ = ["error_index", "error_pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        self.msg = msg
        self.error_index = error_index
        self.pid = pid

    def __reduce__(self):
        return type(self), (self.msg, self.error_index, self.pid)


class ProcessRaisedException(ProcessException):
    """Exception raised when a process failed due to an exception raised by the code."""

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
    ):
        super().__init__(msg, error_index, error_pid)


class ProcessExitedException(ProcessException):
    """Exception raised when a process failed due to a signal or exited with a specific code."""

    __slots__ = ["exit_code"]

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
        exit_code: int,
        signal_name: Optional[str] = None,
    ):
        super().__init__(msg, error_index, error_pid)
        self.exit_code = exit_code
        self.signal_name = signal_name

    def __reduce__(self):
        return (
            type(self),
            (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
        )


def _wrap(fn, i, args, error_file):
    # prctl(2) is a Linux-specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        pass  # SIGINT; killed by parent, do nothing
    except Exception:
        # Propagate exception to parent process, keeping original traceback
        import traceback

        with open(error_file, "wb") as fh:
            pickle.dump(traceback.format_exc(), fh)
        sys.exit(1)


class ProcessContext:
    def __init__(self, processes, error_files):
        self.error_files = error_files
        self.processes = processes
        self.sentinels = {
            process.sentinel: index for index, process in enumerate(processes)
        }

    def pids(self):
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""Join one or more processes within spawn context.

        Attempt to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
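        # Each Process.sentinel is a handle that becomes ready when that
        # process exits, so connection.wait() blocks on all remaining
        # children at once instead of polling them individually.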
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        # Try SIGTERM first, then SIGKILL if a process isn't going down.
        # The reason is that Python signal handling is limited to the main
        # thread; if the main thread is stuck in C/C++ land, it won't get a
        # chance to handle the signal. We have seen processes get stuck this
        # way, never handling SIGTERM.
        timeout: int = 30
        for process in self.processes:
            if process.is_alive():
                log.warning("Terminating process %s via signal SIGTERM", process.pid)
                process.terminate()
        end = time.monotonic() + timeout
        for process in self.processes:
            time_to_wait = max(0, end - time.monotonic())
            process.join(time_to_wait)
        for process in self.processes:
            if process.is_alive():
                log.warning(
                    "Unable to shut down process %s via SIGTERM, forcefully exiting via SIGKILL",
                    process.pid,
                )
                process.kill()
                process.join()

        # The error file will only have been created if the process crashed.
        failed_process = self.processes[error_index]
        if not os.access(self.error_files[error_index], os.R_OK):
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    name = f"<Unknown signal {-exitcode}>"
                raise ProcessExitedException(
                    "process %d terminated with signal %s" % (error_index, name),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                    signal_name=name,
                )
            else:
                raise ProcessExitedException(
                    "process %d terminated with exit code %d" % (error_index, exitcode),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                )

        with open(self.error_files[error_index], "rb") as fh:
            original_trace = pickle.load(fh)
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise ProcessRaisedException(msg, error_index, failed_process.pid)


class SpawnContext(ProcessContext):
    def __init__(self, processes, error_files):
        warnings.warn("SpawnContext is renamed to ProcessContext since 1.4 release.")
        super().__init__(processes, error_files)


# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed
# to be a more general API than mp.spawn. Currently we only document mp.spawn as
# it's the CUDA-compatible start_method. However, in environments like IPython
# notebooks, 'fork' works better than 'spawn'. Every helper function we created
# for mp.spawn is general enough, and backends like XLA can reuse them in Colab
# notebooks as well. Currently we only add this API; we can consider adding it
# to the documentation as needed in the future.
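#
# A minimal usage sketch (the ``worker`` function and its argument are
# hypothetical, for illustration only):
#
#   import torch.multiprocessing as mp
#
#   def worker(i, msg):
#       print(f"process {i}: {msg}")
#
#   if __name__ == "__main__":
#       mp.start_processes(worker, args=("hello",), nprocs=2, start_method="fork")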
def start_processes(
    fn,
    args=(),
    nprocs=1,
    join=True,
    daemon=False,
    start_method="spawn",
):
    # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
    # this function will start processes in parallel if start_method is 'forkserver'.
    # Opt in to this perf optimization by setting the env var TORCH_MP_PARALLEL_START to 1.
    # TODO: investigate why spawn does not work with the thread pool and raises SIGINT.
    if (
        start_method == "forkserver"
        and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
    ):
        log.info("Starting processes in parallel.")
        start_parallel = True
    else:
        # Set env var TORCH_MP_PARALLEL_START to 0 to disable parallel start.
        start_parallel = False

    mp = multiprocessing.get_context(start_method)
    error_files = [None] * nprocs
    processes = [None] * nprocs

    def start_process(i):
        # Each process is assigned a file to write tracebacks to. We
        # use the file being non-empty to indicate an exception
        # occurred (vs. an expected shutdown). Note: this previously
        # used a multiprocessing.Queue, but that can be prone to
        # deadlocks, so we went with a simpler solution for a one-shot
        # message between processes.
        tf = tempfile.NamedTemporaryFile(
            prefix="pytorch-errorfile-", suffix=".pickle", delete=False
        )
        tf.close()
        os.unlink(tf.name)
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, tf.name),
            daemon=daemon,
        )
        process.start()
        return i, process, tf.name

    if not start_parallel:
        for i in range(nprocs):
            idx, process, tf_name = start_process(i)
            error_files[idx] = tf_name
            processes[idx] = process
    else:
        with ThreadPoolExecutor(max_workers=nprocs) as executor:
            futures = [executor.submit(start_process, i) for i in range(nprocs)]
            for fut in as_completed(futures):
                idx, process, tf_name = fut.result()
                # idx and the process rank need to be the same.
                error_files[idx] = tf_name
                processes[idx] = process
    context = ProcessContext(processes, error_files)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass
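
# When ``join=False``, the caller drives the join loop itself. A minimal
# sketch, not part of the API (``worker`` is hypothetical):
#
#   context = start_processes(worker, nprocs=2, join=False)
#   print(context.pids())
#   while not context.join(timeout=5.0):
#       pass  # other work can happen between join attempts
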

def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. If an exception was caught in the child
    process, it is forwarded and its traceback is included in the
    exception raised in the parent process.

    Args:
        fn (function): Function called as the entry point of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed-through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
            daemonic processes will be created.
        start_method (str): (deprecated) this method will always use ``spawn``
            as the start method. To use a different start method,
            use ``start_processes()``.

    Returns:
        ``None`` if ``join`` is ``True``,
        :class:`~ProcessContext` if ``join`` is ``False``

    """
    if start_method != "spawn":
        msg = (
            f"This method only supports start_method=spawn (got: {start_method}).\n"
            "To use a different start_method, use:\n\t\t"
            " torch.multiprocessing.start_processes(...)"
        )
        warnings.warn(msg, FutureWarning, stacklevel=2)
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
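
# A minimal usage sketch for ``spawn`` (``train`` and ``world_size`` are
# hypothetical names, for illustration only):
#
#   import torch.multiprocessing as mp
#
#   def train(i, world_size):
#       print(f"rank {i} of {world_size}")
#
#   if __name__ == "__main__":
#       world_size = 2
#       mp.spawn(train, args=(world_size,), nprocs=world_size)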