#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

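"""APIs for launching and monitoring groups of elastic worker processes.

This module defines :class:`PContext` and its two implementations,
:class:`MultiprocessContext` (workers invoked as a function) and
:class:`SubprocessContext` (workers invoked as a binary), along with
:class:`LogsSpecs`/:class:`DefaultLogsSpecs`, which describe where worker
stdout/stderr and error files are written.
"""
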
import abc
import logging
import os
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
import time
from abc import ABC, abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
from torch.distributed.elastic.multiprocessing.redirects import (
    redirect_stderr,
    redirect_stdout,
)
from torch.distributed.elastic.multiprocessing.subprocess_handler import (
    get_subprocess_handler,
    SubprocessHandler,
)
from torch.distributed.elastic.multiprocessing.tail_log import TailLog


IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"


logger = logging.getLogger(__name__)

__all__ = [
    "DefaultLogsSpecs",
    "SignalException",
    "Std",
    "to_map",
    "RunProcsResult",
    "PContext",
    "get_std_cm",
    "MultiprocessContext",
    "SubprocessContext",
    "LogsDest",
    "LogsSpecs",
]


class SignalException(Exception):
    """
    Raised inside the torchelastic agent process by the termination handler
    when the process receives a death signal.
    """

    def __init__(self, msg: str, sigval: signal.Signals) -> None:
        super().__init__(msg)
        self.sigval = sigval


def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
    """Termination handler that raises an exception on the main process.

    This handler is invoked when the process receives a death signal (SIGTERM, SIGINT).
    It raises ``SignalException``, which user code is expected to handle. Python does not
    terminate the process after the handler returns, so the exception must not be silently
    swallowed; otherwise the process will never be terminated.
    """
    sigval = signal.Signals(signum)
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)


def _get_kill_signal() -> signal.Signals:
    """Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows."""
    if IS_WINDOWS:
        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
    else:
        return signal.SIGKILL


def _get_default_signal() -> signal.Signals:
    """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
    if IS_WINDOWS:
        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
    else:
        return signal.SIGTERM


def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
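    """Check that ``d`` has exactly one entry per local rank in ``[0, nprocs)``."""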
    actual_keys = set(d.keys())
    expected_keys = set(range(nprocs))

    if actual_keys != expected_keys:
        raise RuntimeError(
            f"{what}, local rank mapping mismatch,"
            f" expected: {expected_keys}, actual: {actual_keys}"
        )


_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$"
_VALUE_REGEX = r"^[0123]$"


class Std(IntFlag):
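    """Bitmask flags selecting which standard streams (stdout/stderr) to redirect or tee."""
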
    NONE = 0
    OUT = 1
    ERR = 2
    ALL = OUT | ERR

    @classmethod
    def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
        """
        Example:
        ::

         from_str("0") -> Std.NONE
         from_str("1") -> Std.OUT
         from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}

        Any other input raises an exception
        """

        def to_std(v: str) -> Std:  # type: ignore[return]
            s = Std(int(v))
            if s in Std:
                return s
            # return None -> should NEVER reach here since we regex check input

        if re.match(_VALUE_REGEX, vm):  # vm is a number (e.g. 0)
            return to_std(vm)
        elif re.match(_MAPPING_REGEX, vm):  # vm is a mapping (e.g. 0:1,1:2)
            d: Dict[int, Std] = {}
            for m in vm.split(","):
                i, v = m.split(":")
                d[int(i)] = to_std(v)
            return d
        else:
            raise ValueError(
                f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>"
            )

def to_map(
    val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
) -> Dict[int, Std]:
    """
    Certain APIs take redirect settings either as a single value (e.g. apply to all
    local ranks) or as an explicit user-provided mapping. This is a convenience
    function that normalizes the input into a mapping of local rank to ``Std``.

    Example:
    ::

     to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
     to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
     to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
    """
    if isinstance(val_or_map, Std):
        return dict.fromkeys(range(local_world_size), val_or_map)
    else:
        map = {}
        for i in range(local_world_size):
            map[i] = val_or_map.get(i, Std.NONE)
        return map

@dataclass
class LogsDest:
    """
    For each log type, holds the mapping of local rank ids to file paths.
    """

    stdouts: Dict[int, str] = field(default_factory=dict)
    stderrs: Dict[int, str] = field(default_factory=dict)
    tee_stdouts: Dict[int, str] = field(default_factory=dict)
    tee_stderrs: Dict[int, str] = field(default_factory=dict)
    error_files: Dict[int, str] = field(default_factory=dict)


class LogsSpecs(ABC):
    """
    Defines logs processing and redirection for each worker process.

    Args:
        log_dir:
            Base directory where logs will be written.
        redirects:
            Streams to redirect to files. Pass a single ``Std``
            enum to redirect streams for all workers, or a mapping keyed
            by local_rank to selectively redirect.
        tee:
            Streams to duplicate to stdout/stderr.
            Pass a single ``Std`` enum to duplicate streams for all workers,
            or a mapping keyed by local_rank to selectively duplicate.
    """

    def __init__(
        self,
        log_dir: Optional[str] = None,
        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
        tee: Union[Std, Dict[int, Std]] = Std.NONE,
        local_ranks_filter: Optional[Set[int]] = None,
    ) -> None:
        self._root_log_dir = log_dir
        self._redirects = redirects
        self._tee = tee
        self._local_ranks_filter = local_ranks_filter

    @abstractmethod
    def reify(
        self,
        envs: Dict[int, Dict[str, str]],
    ) -> LogsDest:
        """
        Given the environment variables, builds the log file destinations for each of the local ranks.

        The ``envs`` parameter contains the environment variable dict for each of the local ranks, where entries are defined in:
        :func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`.
        """

    @property
    @abstractmethod
    def root_log_dir(self) -> str:
        pass


class DefaultLogsSpecs(LogsSpecs):
    """
    Default LogsSpecs implementation:

    - `log_dir` will be created if it doesn't exist
    - Generates nested folders for each attempt and rank.
    """

    def __init__(
        self,
        log_dir: Optional[str] = None,
        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
        tee: Union[Std, Dict[int, Std]] = Std.NONE,
        local_ranks_filter: Optional[Set[int]] = None,
    ) -> None:
        if log_dir != os.devnull:
            if not log_dir:
                log_dir = tempfile.mkdtemp(prefix="torchelastic_")
            elif not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            else:
                if os.path.isfile(log_dir):
                    raise NotADirectoryError(f"log_dir: {log_dir} is a file")
        super().__init__(log_dir, redirects, tee, local_ranks_filter)
        # initialized only once
        self._run_log_dir = None

    @property
    def root_log_dir(self) -> str:
        return str(self._root_log_dir)

    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
        os.makedirs(base_log_dir, exist_ok=True)
        dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
        logger.info("log directory set to: %s", dir)
        return dir

    def reify(
        self,
        envs: Dict[int, Dict[str, str]],
    ) -> LogsDest:
        """
        Uses the following scheme to build log destination paths:

        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stdout.log`
        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stderr.log`
        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/error.json`
        """
        nprocs = len(envs)
        global_env = {}  # use only to query properties that are not dependent on a rank
        if nprocs > 0:
            global_env = envs[0]
        else:
            logger.warning(
                "Empty envs map provided when defining logging destinations."
            )
        # Keys are always defined, but values can be missing in unit tests
        run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id")
        restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0")

        attempt_log_dir: str = ""
        if self._root_log_dir != os.devnull:
            if not self._run_log_dir:
                self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)

            attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}")  # type: ignore[call-overload]
            shutil.rmtree(attempt_log_dir, ignore_errors=True)
            os.makedirs(attempt_log_dir)

        if self._root_log_dir == os.devnull:
            attempt_log_dir = os.devnull

        # create subdirs for each local rank in the logs_dir
        # logs_dir
        #       |- 0
        #          |- error.json
        #          |- stdout.log
        #          |- stderr.log
        #       |- ...
        #       |- (nprocs-1)
        redirs = to_map(self._redirects, nprocs)
        ts = to_map(self._tee, nprocs)

        # to tee stdout/stderr we first redirect into a file,
        # then tail -f stdout.log/stderr.log, so add the tee settings to the redirects
        for local_rank, tee_std in ts.items():
            redirect_std = redirs[local_rank]
            redirs[local_rank] = redirect_std | tee_std

        SYS_STREAM = ""  # special case indicating output goes to the console
        stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
        stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
        tee_stdouts: Dict[int, str] = {}
        tee_stderrs: Dict[int, str] = {}
        error_files = {}

        for local_rank in range(nprocs):
            if attempt_log_dir == os.devnull:
                tee_stdouts[local_rank] = os.devnull
                tee_stderrs[local_rank] = os.devnull
                error_files[local_rank] = os.devnull
                envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = ""
            else:
                clogdir = os.path.join(attempt_log_dir, str(local_rank))
                os.mkdir(clogdir)

                rd = redirs[local_rank]
                if (rd & Std.OUT) == Std.OUT:
                    stdouts[local_rank] = os.path.join(clogdir, "stdout.log")
                if (rd & Std.ERR) == Std.ERR:
                    stderrs[local_rank] = os.path.join(clogdir, "stderr.log")

                t = ts[local_rank]
                if t & Std.OUT == Std.OUT:
                    tee_stdouts[local_rank] = stdouts[local_rank]
                if t & Std.ERR == Std.ERR:
                    tee_stderrs[local_rank] = stderrs[local_rank]

                if (
                    self._local_ranks_filter
                    and local_rank not in self._local_ranks_filter
                ):
                    # If stream is tee'd, only write to file, but don't tail
                    if local_rank in tee_stdouts:
                        tee_stdouts.pop(local_rank, None)
                    if local_rank in tee_stderrs:
                        tee_stderrs.pop(local_rank, None)

                    # If stream is not redirected, don't print
                    if stdouts[local_rank] == SYS_STREAM:
                        stdouts[local_rank] = os.devnull
                    if stderrs[local_rank] == SYS_STREAM:
                        stderrs[local_rank] = os.devnull

                error_file = os.path.join(clogdir, "error.json")
                error_files[local_rank] = error_file
                logger.info(
                    "Setting worker%s reply file to: %s", local_rank, error_file
                )
                envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file

        return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files)

    def __repr__(self) -> str:
        return (
            f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, "
            f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DefaultLogsSpecs):
            return False

        return (
            self._root_log_dir == other._root_log_dir
            and self._redirects == other._redirects
            and self._tee == other._tee
            and self._local_ranks_filter == other._local_ranks_filter
        )


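# Illustrative sketch (not part of the module): how a ``DefaultLogsSpecs`` is typically
# reified into per-rank log destinations. The "/tmp/my_logs" path and the ``envs``
# mapping below are hypothetical placeholder values.
#
#   specs = DefaultLogsSpecs(log_dir="/tmp/my_logs", redirects=Std.ALL, tee=Std.OUT)
#   logs_dest = specs.reify(
#       envs={0: {"TORCHELASTIC_RUN_ID": "run0", "TORCHELASTIC_RESTART_COUNT": "0"}}
#   )
#   # logs_dest.stdouts[0]     -> /tmp/my_logs/run0_*/attempt_0/0/stdout.log
#   # logs_dest.error_files[0] -> /tmp/my_logs/run0_*/attempt_0/0/error.json

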
@dataclass
class RunProcsResult:
    """
    Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``.

    Note the following:

    1. All fields are mapped by local rank
    2. ``return_values`` - only populated for functions (not the binaries).
    3. ``stdouts`` - path to stdout.log (empty string if no redirect)
    4. ``stderrs`` - path to stderr.log (empty string if no redirect)

    """

    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, ProcessFailure] = field(default_factory=dict)
    stdouts: Dict[int, str] = field(default_factory=dict)
    stderrs: Dict[int, str] = field(default_factory=dict)

    def is_failed(self) -> bool:
        return len(self.failures) > 0


class PContext(abc.ABC):
    """
    The base class that standardizes operations over a set of processes that are launched via different mechanisms.

    The name ``PContext`` is intentional, to disambiguate it from ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively), because
                 tee is implemented as a redirect + tail -f <stdout/stderr.log>
    """

    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        self.name = name
        # validate that all mappings have the same number of keys and
        # all local ranks are accounted for
        nprocs = len(args)

        # TODO log_line_prefixes can be expanded too
        logs_dest = logs_specs.reify(envs)

        _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts")
        _validate_full_rank(logs_dest.stderrs, nprocs, "stderrs")

        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = logs_dest.stdouts
        self.stderrs = logs_dest.stderrs
        self.error_files = logs_dest.error_files
        self.nprocs = nprocs

        self._stdout_tail = TailLog(
            name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes
        )
        self._stderr_tail = TailLog(
            name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes
        )

    def start(self) -> None:
        """Start processes using parameters defined in the constructor."""
        if threading.current_thread() is threading.main_thread():
            signal.signal(signal.SIGTERM, _terminate_process_handler)
            signal.signal(signal.SIGINT, _terminate_process_handler)
            if not IS_WINDOWS:
                signal.signal(signal.SIGHUP, _terminate_process_handler)
                signal.signal(signal.SIGQUIT, _terminate_process_handler)
        else:
            logger.warning(
                "Failed to register signal handlers since torchelastic is running on a child thread. "
                "This could lead to orphaned worker processes if torchrun is terminated."
            )
        self._start()
        self._stdout_tail.start()
        self._stderr_tail.start()

    @abc.abstractmethod
    def _start(self) -> None:
        """Start processes using the strategy defined in a particular context."""
        raise NotImplementedError

    @abc.abstractmethod
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Poll the run status of the processes running under this context.
        This method follows an "all-or-nothing" policy and returns
        a ``RunProcsResult`` object if either all processes complete
        successfully or any process fails. Returns ``None`` if
        all processes are still running.
        """
        raise NotImplementedError

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Wait for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).

        .. note:: This module registers SIGTERM and SIGINT signal handlers that raise
                  ``SignalException`` when the signals are received. It is up to the consumer
                  of the code to properly handle the exception. It is important not to swallow
                  the exception, otherwise the process would not terminate. An example of the
                  typical workflow:

        .. code-block:: python

            pc = start_processes(...)
            try:
                pc.wait(1)
                .. do some other work
            except SignalException as e:
                pc.shutdown(e.sigval, timeout=30)

        If SIGTERM or SIGINT occurs, the code above will try to shutdown the child processes by
        propagating the received signal. If the child processes do not terminate within the
        timeout, they will be sent SIGKILL.
        """
        if timeout == 0:
            return self._poll()

        if timeout < 0:
            timeout = sys.maxsize

        expiry = time.time() + timeout
        while time.time() < expiry:
            pr = self._poll()
            if pr:
                return pr
            time.sleep(period)

        return None

    @abc.abstractmethod
    def pids(self) -> Dict[int, int]:
        """Return pids of processes mapped by their respective local_ranks."""
        raise NotImplementedError

    @abc.abstractmethod
    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).
        """
        raise NotImplementedError

    def close(
        self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
    ) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).

        Args:
            death_sig: Death signal to terminate processes.
            timeout: Time to wait for processes to finish; if a process is
                still alive after this time, it will be terminated via SIGKILL.
        """
        if not death_sig:
            death_sig = _get_default_signal()
        self._close(death_sig=death_sig, timeout=timeout)
        if self._stdout_tail:
            self._stdout_tail.stop()
        if self._stderr_tail:
            self._stderr_tail.stop()


def get_std_cm(std_rd: str, redirect_fn):
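    """Return a context manager that redirects a std stream to the ``std_rd`` file.

    Returns a no-op ``nullcontext`` on Windows/macOS or when ``std_rd`` is empty.
    """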
    if IS_WINDOWS or IS_MACOS or not std_rd:
        return nullcontext()
    else:
        return redirect_fn(std_rd)


def _wrap(
    local_rank: int,
    fn: Callable,
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    stdout_redirects: Dict[int, str],  # redirect file for stdout (to console if None)
    stderr_redirects: Dict[int, str],  # redirect file for stderr (to console if None)
    ret_vals: Dict[int, mp.SimpleQueue],
    queue_finished_reading_event: synchronize.Event,
) -> None:
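    """Per-process entry point used by ``MultiprocessContext``.

    Sets up stdout/stderr redirection and environment variables for ``local_rank``,
    runs ``fn`` under ``record()`` so failures are written to the error file, puts the
    return value on the rank's queue, and waits until the parent has read it.
    """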
    # get the per-rank params up front so we fail fast if no mapping is found
    args_ = args[local_rank]
    env_ = envs[local_rank]
    ret_val_ = ret_vals[local_rank]

    stdout_rd = stdout_redirects[local_rank]
    stderr_rd = stderr_redirects[local_rank]

    stdout_cm = get_std_cm(stdout_rd, redirect_stdout)
    stderr_cm = get_std_cm(stderr_rd, redirect_stderr)

    for k, v in env_.items():
        os.environ[k] = v

    with stdout_cm, stderr_cm:
        ret = record(fn)(*args_)
    ret_val_.put(ret)
    queue_finished_reading_event.wait()


615
616class MultiprocessContext(PContext):
617    """``PContext`` holding worker processes invoked as a function."""
618
619    def __init__(
620        self,
621        name: str,
622        entrypoint: Callable,
623        args: Dict[int, Tuple],
624        envs: Dict[int, Dict[str, str]],
625        start_method: str,
626        logs_specs: LogsSpecs,
627        log_line_prefixes: Optional[Dict[int, str]] = None,
628    ):
629        super().__init__(
630            name,
631            entrypoint,
632            args,
633            envs,
634            logs_specs,
635            log_line_prefixes,
636        )
637
638        self.start_method = start_method
639        # each ret_val queue will always contain a single element.
640        self._ret_vals = {
641            local_rank: mp.get_context(self.start_method).SimpleQueue()
642            for local_rank in range(self.nprocs)
643        }
644
645        # see comments in ``join()`` for what this is
646        self._return_values: Dict[int, Any] = {}
647        self._pc: Optional[mp.ProcessContext] = None
648        # Note: set method should ONLY be invoked for the use case when all processes finished
649        # successfully. If any process died on event.wait() calling set() method will deadlock.
650        self._worker_finished_event = mp.get_context(self.start_method).Event()
651
    def _start(self):
        if self._pc:
            raise ValueError(
                "The process context is already initialized."
                " Most likely the start method got called twice."
            )
        self._pc = mp.start_processes(
            fn=_wrap,
            args=(
                self.entrypoint,
                self.args,
                self.envs,
                self.stdouts,
                self.stderrs,
                self._ret_vals,
                self._worker_finished_event,
            ),
            nprocs=self.nprocs,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )

    def _is_done(self) -> bool:
        return len(self._return_values) == self.nprocs

    def _poll(self) -> Optional[RunProcsResult]:
        assert self._pc is not None  # assertion for mypy type checker

        try:
            # torch.mp.ProcessContext throws an Exception if some/all of
            # the worker processes failed
            # timeout < 0 checks worker status and returns immediately
            # Join will never return success since we use synchronize.Event to wait
            # for all processes to finish.
            self._pc.join(-1)

            # IMPORTANT: we use multiprocessing.Queue to carry worker return values
            # back to the parent, the worker process will wait before terminating
            # until all the buffered items are fed by the feeder thread to the underlying
            # pipe. Hence to prevent deadlocks on large return values,
            # we opportunistically try queue.get on each join call
            # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
            for local_rank in range(0, self.nprocs):
                return_queue = self._ret_vals[local_rank]
                if not return_queue.empty():
                    # save the return values temporarily into a member var
                    self._return_values[local_rank] = return_queue.get()

            if self._is_done():
                # we should ALWAYS have ALL the return values when all the processes are done
                self._worker_finished_event.set()

                # At this point workers finished running the user function
                # But the child process might still have not exited. Wait for them.
                # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exit.
                while not self._pc.join():
                    logger.debug(
                        "entrypoint fn finished, waiting for all child procs to exit..."
                    )

                _validate_full_rank(
                    self._return_values, self.nprocs, "return_value queue"
                )
                self.close()
                return RunProcsResult(
                    return_values=self._return_values,
                    stdouts=self.stdouts,
                    stderrs=self.stderrs,
                )
            else:
                return None
        except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
            failed_local_rank = e.error_index

            # entrypoint for MultiprocessContext will always be a Callable
            fn_name = self.entrypoint.__qualname__  # type: ignore[union-attr]
            failed_proc = self._pc.processes[failed_local_rank]
            error_filepath = self.error_files[failed_local_rank]

            logger.exception(
                "failed (exitcode: %s)"
                " local_rank: %s (pid: %s)"
                " of fn: %s (start_method: %s)",
                failed_proc.exitcode,
                failed_local_rank,
                e.pid,
                fn_name,
                self.start_method,
            )

            self.close()
            return RunProcsResult(
                failures={
                    failed_local_rank: ProcessFailure(
                        local_rank=failed_local_rank,
                        pid=e.pid,
                        exitcode=failed_proc.exitcode,
                        error_file=error_filepath,
                    )
                },
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )

    def pids(self) -> Dict[int, int]:
        assert self._pc is not None  # assertion for mypy type checking
        return dict(enumerate(self._pc.pids()))

    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        if not self._pc:
            return
        for proc in self._pc.processes:
            if proc.is_alive():
                logger.warning(
                    "Closing process %s via signal %s", proc.pid, death_sig.name
                )
                try:
                    os.kill(proc.pid, death_sig)
                except ProcessLookupError:
                    # If the process already exited for some reason,
                    # `ProcessLookupError` is raised; it is safe to ignore it.
                    pass
        end = time.monotonic() + timeout
        for proc in self._pc.processes:
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            proc.join(time_to_wait)
        for proc in self._pc.processes:
            if proc.is_alive():
                logger.warning(
                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
                    proc.pid,
                    death_sig,
                    _get_kill_signal(),
                )
                try:
                    os.kill(proc.pid, _get_kill_signal())
                except ProcessLookupError:
                    # If the process already exited for some reason,
                    # `ProcessLookupError` is raised; it is safe to ignore it.
                    pass
            proc.join()


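# Illustrative sketch (not part of the module): launching a function entrypoint with
# ``MultiprocessContext``. The ``trainer`` function, the per-rank args/envs and the
# start method below are hypothetical placeholder values.
#
#   def trainer(x):
#       return x * 2
#
#   ctx = MultiprocessContext(
#       name="trainer",
#       entrypoint=trainer,
#       args={0: (1,), 1: (2,)},
#       envs={0: {}, 1: {}},
#       start_method="spawn",
#       logs_specs=DefaultLogsSpecs(),
#   )
#   ctx.start()
#   result = ctx.wait()  # RunProcsResult with return_values={0: 2, 1: 4} on success

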
class SubprocessContext(PContext):
    """``PContext`` holding worker processes invoked as a binary."""

    def __init__(
        self,
        name: str,
        entrypoint: str,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            logs_specs,
            log_line_prefixes,
        )

        # state vector; tracks which local ranks have not yet finished
        self._running_local_ranks: Set[int] = set(range(self.nprocs))
        self._failures: Dict[int, ProcessFailure] = {}
        self.subprocess_handlers: Dict[int, SubprocessHandler] = {}

    def _start(self):
        if self.subprocess_handlers:
            raise ValueError(
                "The subprocess handlers are already initialized. Most likely the start method got called twice."
            )
        self.subprocess_handlers = {
            local_rank: get_subprocess_handler(
                entrypoint=self.entrypoint,  # type: ignore[arg-type] # entrypoint is always a str
                args=self.args[local_rank],
                env=self.envs[local_rank],
                stdout=self.stdouts[local_rank],
                stderr=self.stderrs[local_rank],
                local_rank_id=local_rank,
            )
            for local_rank in range(self.nprocs)
        }

    def _poll(self) -> Optional[RunProcsResult]:
        done_local_ranks = set()
        for local_rank in self._running_local_ranks:
            handler = self.subprocess_handlers[local_rank]
            exitcode = handler.proc.poll()
            if exitcode is not None:
                done_local_ranks.add(local_rank)
                if exitcode != 0:  # failed or signaled
                    self._failures[local_rank] = ProcessFailure(
                        local_rank=local_rank,
                        pid=handler.proc.pid,
                        exitcode=exitcode,
                        error_file=self.error_files[local_rank],
                    )
                # else: --> succeeded; nothing to do

        self._running_local_ranks.difference_update(done_local_ranks)

        # if ALL procs are finished or ANY have failed
        if not self._running_local_ranks or self._failures:
            self.close()  # terminate all running procs
            result = RunProcsResult(
                failures=self._failures,
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )
            if result.is_failed():
                first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
                logger.error(
                    "failed (exitcode: %s)"
                    " local_rank: %s (pid: %s)"
                    " of binary: %s",
                    first_failure.exitcode,
                    first_failure.local_rank,
                    first_failure.pid,
                    self.entrypoint,
                )
            else:
                # Populate return values with dummy values. This provides consistency with MultiprocessContext
                result.return_values = dict.fromkeys(range(self.nprocs))

            return result
        else:  # there are no failures and procs still running
            return None

    def pids(self) -> Dict[int, int]:
        return {
            local_rank: sh.proc.pid
            for local_rank, sh in self.subprocess_handlers.items()
        }

    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        if not self.subprocess_handlers:
            return
        for handler in self.subprocess_handlers.values():
            if handler.proc.poll() is None:
                logger.warning(
                    "Sending process %s closing signal %s",
                    handler.proc.pid,
                    death_sig.name,
                )
                handler.close(death_sig=death_sig)
        end = time.monotonic() + timeout
        for handler in self.subprocess_handlers.values():
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            try:
                handler.proc.wait(time_to_wait)
            except subprocess.TimeoutExpired:
                # Ignore the timeout expired exception, since
                # the child process will be forcefully terminated via SIGKILL
                pass
        for handler in self.subprocess_handlers.values():
            if handler.proc.poll() is None:
                logger.warning(
                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
                    handler.proc.pid,
                    death_sig,
                    _get_kill_signal(),
                )
                handler.close(death_sig=_get_kill_signal())
                handler.proc.wait()
924