#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import abc
import logging
import os
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
import time
from abc import ABC, abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
from torch.distributed.elastic.multiprocessing.redirects import (
    redirect_stderr,
    redirect_stdout,
)
from torch.distributed.elastic.multiprocessing.subprocess_handler import (
    get_subprocess_handler,
    SubprocessHandler,
)
from torch.distributed.elastic.multiprocessing.tail_log import TailLog


IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"


logger = logging.getLogger(__name__)

__all__ = [
    "DefaultLogsSpecs",
    "SignalException",
    "Std",
    "to_map",
    "RunProcsResult",
    "PContext",
    "get_std_cm",
    "MultiprocessContext",
    "SubprocessContext",
    "LogsDest",
    "LogsSpecs",
]


class SignalException(Exception):
    """
    Raised inside the torchelastic agent process by the termination handler
    when the process receives a death signal.
    """

    def __init__(self, msg: str, sigval: signal.Signals) -> None:
        super().__init__(msg)
        self.sigval = sigval


def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
    """Termination handler that raises exceptions on the main process.

    When the process receives a death signal (SIGTERM, SIGINT), this termination handler
    is invoked. It raises a ``SignalException`` that should be processed by the user code.
    Python does not terminate the process after the termination handler finishes, so the
    exception must not be silently ignored; otherwise the process will never terminate.
    """
    sigval = signal.Signals(signum)
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)


def _get_kill_signal() -> signal.Signals:
    """Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows."""
    if IS_WINDOWS:
        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
    else:
        return signal.SIGKILL


def _get_default_signal() -> signal.Signals:
    """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
    if IS_WINDOWS:
        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
    else:
        return signal.SIGTERM

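
# Illustrative sketch (not part of the module): how the termination handler above is
# typically consumed by caller code. ``start_processes`` stands in for any code that
# creates a ``PContext``; treat the call and the timeout value as assumptions.
#
#   pc = start_processes(...)
#   try:
#       result = pc.wait(timeout=-1)
#   except SignalException as e:
#       # propagate the received death signal to the workers, then hard-kill stragglers
#       pc.close(e.sigval, timeout=30)
#       raise
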
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
    actual_keys = set(d.keys())
    expected_keys = set(range(nprocs))

    if actual_keys != expected_keys:
        raise RuntimeError(
            f"{what}, local rank mapping mismatch,"
            f" expected: {expected_keys}, actual: {actual_keys}"
        )


_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$"
_VALUE_REGEX = r"^[0123]$"


class Std(IntFlag):
    NONE = 0
    OUT = 1
    ERR = 2
    ALL = OUT | ERR

    @classmethod
    def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
        """
        Example:
        ::

         from_str("0") -> Std.NONE
         from_str("1") -> Std.OUT
         from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}

        Any other input raises an exception
        """

        def to_std(v: str) -> Std:  # type: ignore[return]
            s = Std(int(v))
            if s in Std:
                return s
            # return None -> should NEVER reach here since we regex check input

        if re.match(_VALUE_REGEX, vm):  # vm is a number (e.g. 0)
            return to_std(vm)
        elif re.match(_MAPPING_REGEX, vm):  # vm is a mapping (e.g. 0:1,1:2)
            d: Dict[int, Std] = {}
            for m in vm.split(","):
                i, v = m.split(":")
                d[int(i)] = to_std(v)
            return d
        else:
            raise ValueError(
                f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>"
            )


def to_map(
    val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
) -> Dict[int, Std]:
    """
    Certain APIs take redirect settings either as a single value (e.g. apply to all
    local ranks) or as an explicit user-provided mapping. This convenience function
    normalizes a single value or a partial mapping into a complete mapping.

    Example:
    ::

     to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
     to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
     to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
    """
    if isinstance(val_or_map, Std):
        return dict.fromkeys(range(local_world_size), val_or_map)
    else:
        map = {}
        for i in range(local_world_size):
            map[i] = val_or_map.get(i, Std.NONE)
        return map


@dataclass
class LogsDest:
    """
    For each log type, holds the mapping of local rank IDs to file paths.
    """

    stdouts: Dict[int, str] = field(default_factory=dict)
    stderrs: Dict[int, str] = field(default_factory=dict)
    tee_stdouts: Dict[int, str] = field(default_factory=dict)
    tee_stderrs: Dict[int, str] = field(default_factory=dict)
    error_files: Dict[int, str] = field(default_factory=dict)

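
# Illustrative sketch (not part of the module): combining ``Std.from_str`` with ``to_map``
# to turn a CLI-style redirect string into a complete per-rank mapping. The string below
# is an assumed example value, e.g. what a ``--redirects`` flag might carry.
#
#   redirects = Std.from_str("0:3,1:1")              # {0: Std.ALL, 1: Std.OUT}
#   per_rank = to_map(redirects, local_world_size=4)
#   # per_rank == {0: Std.ALL, 1: Std.OUT, 2: Std.NONE, 3: Std.NONE}
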
class LogsSpecs(ABC):
    """
    Defines logs processing and redirection for each worker process.

    Args:
        log_dir:
            Base directory where logs will be written.
        redirects:
            Streams to redirect to files. Pass a single ``Std``
            enum to redirect for all workers, or a mapping keyed
            by local_rank to selectively redirect.
        tee:
            Streams to duplicate to stdout/stderr.
            Pass a single ``Std`` enum to duplicate streams for all workers,
            or a mapping keyed by local_rank to selectively duplicate.
        local_ranks_filter:
            If provided, only the logs of these local ranks are tailed to the console.
    """

    def __init__(
        self,
        log_dir: Optional[str] = None,
        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
        tee: Union[Std, Dict[int, Std]] = Std.NONE,
        local_ranks_filter: Optional[Set[int]] = None,
    ) -> None:
        self._root_log_dir = log_dir
        self._redirects = redirects
        self._tee = tee
        self._local_ranks_filter = local_ranks_filter

    @abstractmethod
    def reify(
        self,
        envs: Dict[int, Dict[str, str]],
    ) -> LogsDest:
        """
        Given the environment variables, builds the destination log file paths for each of the local ranks.

        The ``envs`` parameter contains the env variables dict for each of the local ranks, where entries are defined in:
        :func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`.
        """

    @property
    @abstractmethod
    def root_log_dir(self) -> str:
        pass


class DefaultLogsSpecs(LogsSpecs):
    """
    Default LogsSpecs implementation:

    - `log_dir` will be created if it doesn't exist
    - Generates nested folders for each attempt and rank.
    """

    def __init__(
        self,
        log_dir: Optional[str] = None,
        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
        tee: Union[Std, Dict[int, Std]] = Std.NONE,
        local_ranks_filter: Optional[Set[int]] = None,
    ) -> None:
        if log_dir != os.devnull:
            if not log_dir:
                log_dir = tempfile.mkdtemp(prefix="torchelastic_")
            elif not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            else:
                if os.path.isfile(log_dir):
                    raise NotADirectoryError(f"log_dir: {log_dir} is a file")
        super().__init__(log_dir, redirects, tee, local_ranks_filter)
        # initialized only once
        self._run_log_dir = None

    @property
    def root_log_dir(self) -> str:
        return str(self._root_log_dir)

    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
        os.makedirs(base_log_dir, exist_ok=True)
        dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
        logger.info("log directory set to: %s", dir)
        return dir

    def reify(
        self,
        envs: Dict[int, Dict[str, str]],
    ) -> LogsDest:
        """
        Uses the following scheme to build log destination paths:

        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stdout.log`
        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stderr.log`
        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/error.json`
        """
        nprocs = len(envs)
        global_env = {}  # use only to query properties that are not dependent on a rank
        if nprocs > 0:
            global_env = envs[0]
        else:
            logger.warning(
                "Empty envs map provided when defining logging destinations."
            )
        # Keys are always defined, but values can be missing in unit tests
        run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id")
        restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0")

        attempt_log_dir: str = ""
        if self._root_log_dir != os.devnull:
            if not self._run_log_dir:
                self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)

            attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}")  # type: ignore[call-overload]
            shutil.rmtree(attempt_log_dir, ignore_errors=True)
            os.makedirs(attempt_log_dir)

        if self._root_log_dir == os.devnull:
            attempt_log_dir = os.devnull

        # create subdirs for each local rank in the logs_dir
        # logs_dir
        #       |- 0
        #          |- error.json
        #          |- stdout.log
        #          |- stderr.log
        #       |- ...
        #       |- (nprocs-1)
        redirs = to_map(self._redirects, nprocs)
        ts = to_map(self._tee, nprocs)

        # to tee stdout/stderr we first redirect into a file
        # then tail -f stdout.log/stderr.log so add tee settings to redirects
        for local_rank, tee_std in ts.items():
            redirect_std = redirs[local_rank]
            redirs[local_rank] = redirect_std | tee_std

        SYS_STREAM = ""  # special case to indicate to output to console
        stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
        stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
        tee_stdouts: Dict[int, str] = {}
        tee_stderrs: Dict[int, str] = {}
        error_files = {}

        for local_rank in range(nprocs):
            if attempt_log_dir == os.devnull:
                tee_stdouts[local_rank] = os.devnull
                tee_stderrs[local_rank] = os.devnull
                error_files[local_rank] = os.devnull
                envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = ""
            else:
                clogdir = os.path.join(attempt_log_dir, str(local_rank))
                os.mkdir(clogdir)

                rd = redirs[local_rank]
                if (rd & Std.OUT) == Std.OUT:
                    stdouts[local_rank] = os.path.join(clogdir, "stdout.log")
                if (rd & Std.ERR) == Std.ERR:
                    stderrs[local_rank] = os.path.join(clogdir, "stderr.log")

                t = ts[local_rank]
                if t & Std.OUT == Std.OUT:
                    tee_stdouts[local_rank] = stdouts[local_rank]
                if t & Std.ERR == Std.ERR:
                    tee_stderrs[local_rank] = stderrs[local_rank]

                if (
                    self._local_ranks_filter
                    and local_rank not in self._local_ranks_filter
                ):
                    # If stream is tee'd, only write to file, but don't tail
                    if local_rank in tee_stdouts:
                        tee_stdouts.pop(local_rank, None)
                    if local_rank in tee_stderrs:
                        tee_stderrs.pop(local_rank, None)

                    # If stream is not redirected, don't print
                    if stdouts[local_rank] == SYS_STREAM:
                        stdouts[local_rank] = os.devnull
                    if stderrs[local_rank] == SYS_STREAM:
                        stderrs[local_rank] = os.devnull

                error_file = os.path.join(clogdir, "error.json")
                error_files[local_rank] = error_file
                logger.info(
                    "Setting worker%s reply file to: %s", local_rank, error_file
                )
                envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file

        return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files)

    def __repr__(self) -> str:
        return (
            f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, "
            f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})"
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DefaultLogsSpecs):
            return False

        return (
            self._root_log_dir == other._root_log_dir
            and self._redirects == other._redirects
            and self._tee == other._tee
            and self._local_ranks_filter == other._local_ranks_filter
        )

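
# Illustrative sketch (not part of the module): what ``DefaultLogsSpecs.reify`` produces for a
# two-worker job. The env values below are assumed placeholders for what the agent normally sets.
#
#   specs = DefaultLogsSpecs(log_dir="/tmp/run_logs", redirects=Std.ALL, tee=Std.OUT)
#   envs = {
#       0: {"TORCHELASTIC_RUN_ID": "job42", "TORCHELASTIC_RESTART_COUNT": "0"},
#       1: {"TORCHELASTIC_RUN_ID": "job42", "TORCHELASTIC_RESTART_COUNT": "0"},
#   }
#   dest = specs.reify(envs)
#   # dest.stdouts[0] -> /tmp/run_logs/job42_<random>/attempt_0/0/stdout.log
#   # dest.tee_stdouts[0] points at the same file, so it is also tailed to the console
#   # envs[0]["TORCHELASTIC_ERROR_FILE"] is set to .../attempt_0/0/error.json
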
@dataclass
class RunProcsResult:
    """
    Results of a completed run of processes started with ``start_processes()``. Returned by ``PContext``.

    Note the following:

    1. All fields are mapped by local rank
    2. ``return_values`` - only populated for functions (not binaries).
    3. ``stdouts`` - path to stdout.log (empty string if no redirect)
    4. ``stderrs`` - path to stderr.log (empty string if no redirect)

    """

    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, ProcessFailure] = field(default_factory=dict)
    stdouts: Dict[int, str] = field(default_factory=dict)
    stderrs: Dict[int, str] = field(default_factory=dict)

    def is_failed(self) -> bool:
        return len(self.failures) > 0


class PContext(abc.ABC):
    """
    The base class that standardizes operations over a set of processes that are launched via different mechanisms.

    The name ``PContext`` is intentional, to disambiguate it from ``torch.multiprocessing.ProcessContext``.

    .. warning:: stdouts and stderrs should ALWAYS be a superset of
                 tee_stdouts and tee_stderrs (respectively), because
                 tee is implemented as a redirect + ``tail -f <stdout/stderr.log>``
    """

    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        self.name = name
        # validate that all mappings have the same number of keys and
        # all local ranks are accounted for
        nprocs = len(args)

        # TODO log_line_prefixes can be expanded too
        logs_dest = logs_specs.reify(envs)

        _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts")
        _validate_full_rank(logs_dest.stderrs, nprocs, "stderrs")

        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = logs_dest.stdouts
        self.stderrs = logs_dest.stderrs
        self.error_files = logs_dest.error_files
        self.nprocs = nprocs

        self._stdout_tail = TailLog(
            name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes
        )
        self._stderr_tail = TailLog(
            name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes
        )

    def start(self) -> None:
        """Start processes using parameters defined in the constructor."""
        if threading.current_thread() is threading.main_thread():
            signal.signal(signal.SIGTERM, _terminate_process_handler)
            signal.signal(signal.SIGINT, _terminate_process_handler)
            if not IS_WINDOWS:
                signal.signal(signal.SIGHUP, _terminate_process_handler)
                signal.signal(signal.SIGQUIT, _terminate_process_handler)
        else:
            logger.warning(
                "Failed to register signal handlers since torchelastic is running on a child thread. "
                "This could lead to orphaned worker processes if torchrun is terminated."
            )
        self._start()
        self._stdout_tail.start()
        self._stderr_tail.start()

    @abc.abstractmethod
    def _start(self) -> None:
        """Start processes using strategy defined in a particular context."""
        raise NotImplementedError

    @abc.abstractmethod
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Poll the run status of the processes running under this context.
        This method follows an "all-or-nothing" policy and returns
        a ``RunProcsResult`` object if either all processes complete
        successfully or any process fails. Returns ``None`` if
        all processes are still running.
        """
        raise NotImplementedError

    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Wait for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).

        .. note:: This library registers SIGTERM and SIGINT signal handlers that raise
                  ``SignalException`` when the signals are received. It is up to the consumer
                  of the code to properly handle the exception. It is important not to swallow
                  the exception, otherwise the process would not terminate. An example of the
                  typical workflow:

        .. code-block:: python

            pc = start_processes(...)
            try:
                pc.wait(1)
                # do some other work
            except SignalException as e:
                pc.close(e.sigval, timeout=30)

        If SIGTERM or SIGINT occurs, the code above will try to shutdown child processes by propagating
        the received signal. If the child processes do not terminate within the timeout, the process will
        send SIGKILL.
        """
        if timeout == 0:
            return self._poll()

        if timeout < 0:
            timeout = sys.maxsize

        expiry = time.time() + timeout
        while time.time() < expiry:
            pr = self._poll()
            if pr:
                return pr
            time.sleep(period)

        return None

    @abc.abstractmethod
    def pids(self) -> Dict[int, int]:
        """Return pids of processes mapped by their respective local_ranks."""
        raise NotImplementedError

    @abc.abstractmethod
    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).
        """
        raise NotImplementedError

    def close(
        self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
    ) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).

        Args:
            death_sig: Death signal to terminate processes.
            timeout: Time to wait for processes to finish; if a process is
                still alive after this time, it will be terminated via SIGKILL.
        """
        if not death_sig:
            death_sig = _get_default_signal()
        self._close(death_sig=death_sig, timeout=timeout)
        if self._stdout_tail:
            self._stdout_tail.stop()
        if self._stderr_tail:
            self._stderr_tail.stop()

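
# Illustrative sketch (not part of the module): the typical PContext lifecycle as driven by an
# agent. The concrete context ``ctx`` is assumed to have been built elsewhere (e.g. via
# ``start_processes``); the timeout values are placeholders.
#
#   ctx.start()                                     # register signal handlers, launch workers, start log tailing
#   try:
#       result = ctx.wait(timeout=-1, period=1)     # block until all succeed or any fails
#       if result and result.is_failed():
#           first = min(result.failures.values(), key=lambda f: f.timestamp)
#           ...                                     # surface the first failure to the caller
#   finally:
#       ctx.close(timeout=30)                       # SIGTERM by default, then SIGKILL for stragglers
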
def get_std_cm(std_rd: str, redirect_fn):
    if IS_WINDOWS or IS_MACOS or not std_rd:
        return nullcontext()
    else:
        return redirect_fn(std_rd)


def _wrap(
    local_rank: int,
    fn: Callable,
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    stdout_redirects: Dict[int, str],  # redirect file for stdout (to console if None)
    stderr_redirects: Dict[int, str],  # redirect file for stderr (to console if None)
    ret_vals: Dict[int, mp.SimpleQueue],
    queue_finished_reading_event: synchronize.Event,
) -> None:
    # get the per-rank params up front so we fail fast if no mapping is found
    args_ = args[local_rank]
    env_ = envs[local_rank]
    ret_val_ = ret_vals[local_rank]

    stdout_rd = stdout_redirects[local_rank]
    stderr_rd = stderr_redirects[local_rank]

    stdout_cm = get_std_cm(stdout_rd, redirect_stdout)
    stderr_cm = get_std_cm(stderr_rd, redirect_stderr)

    for k, v in env_.items():
        os.environ[k] = v

    with stdout_cm, stderr_cm:
        ret = record(fn)(*args_)

    ret_val_.put(ret)
    queue_finished_reading_event.wait()


class MultiprocessContext(PContext):
    """``PContext`` holding worker processes invoked as a function."""

    def __init__(
        self,
        name: str,
        entrypoint: Callable,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        start_method: str,
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            logs_specs,
            log_line_prefixes,
        )

        self.start_method = start_method
        # each ret_val queue will always contain a single element.
        self._ret_vals = {
            local_rank: mp.get_context(self.start_method).SimpleQueue()
            for local_rank in range(self.nprocs)
        }

        # see comments in ``join()`` for what this is
        self._return_values: Dict[int, Any] = {}
        self._pc: Optional[mp.ProcessContext] = None
        # Note: set method should ONLY be invoked for the use case when all processes finished
        # successfully. If any process died on event.wait() calling set() method will deadlock.
        self._worker_finished_event = mp.get_context(self.start_method).Event()

    def _start(self):
        if self._pc:
            raise ValueError(
                "The process context is already initialized."
                " Most likely the start method got called twice."
            )
        self._pc = mp.start_processes(
            fn=_wrap,
            args=(
                self.entrypoint,
                self.args,
                self.envs,
                self.stdouts,
                self.stderrs,
                self._ret_vals,
                self._worker_finished_event,
            ),
            nprocs=self.nprocs,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )

    def _is_done(self) -> bool:
        return len(self._return_values) == self.nprocs

    def _poll(self) -> Optional[RunProcsResult]:
        assert self._pc is not None  # assertion for mypy type checker

        try:
            # torch.mp.ProcessContext throws an exception if some/all of the
            # worker processes failed
            # timeout < 0 checks worker status and returns immediately
            # Join will never return success since we use synchronize.Event to wait
            # for all processes to finish.
            self._pc.join(-1)

            # IMPORTANT: we use multiprocessing.Queue to carry worker return values
            # back to the parent, the worker process will wait before terminating
            # until all the buffered items are fed by the feeder thread to the underlying
            # pipe. Hence to prevent deadlocks on large return values,
            # we opportunistically try queue.get on each join call
            # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
            for local_rank in range(0, self.nprocs):
                return_queue = self._ret_vals[local_rank]
                if not return_queue.empty():
                    # save the return values temporarily into a member var
                    self._return_values[local_rank] = return_queue.get()

            if self._is_done():
                # we should ALWAYS have ALL the return values when all the processes are done
                self._worker_finished_event.set()

                # At this point workers finished running the user function,
                # but the child processes might still not have exited. Wait for them.
                # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exit.
                while not self._pc.join():
                    logger.debug(
                        "entrypoint fn finished, waiting for all child procs to exit..."
                    )

                _validate_full_rank(
                    self._return_values, self.nprocs, "return_value queue"
                )
                self.close()
                return RunProcsResult(
                    return_values=self._return_values,
                    stdouts=self.stdouts,
                    stderrs=self.stderrs,
                )
            else:
                return None
        except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
            failed_local_rank = e.error_index

            # entrypoint for MultiprocessContext will always be a Callable
            fn_name = self.entrypoint.__qualname__  # type: ignore[union-attr]
            failed_proc = self._pc.processes[failed_local_rank]
            error_filepath = self.error_files[failed_local_rank]

            logger.exception(
                "failed (exitcode: %s)"
                " local_rank: %s (pid: %s)"
                " of fn: %s (start_method: %s)",
                failed_proc.exitcode,
                failed_local_rank,
                e.pid,
                fn_name,
                self.start_method,
            )

            self.close()
            return RunProcsResult(
                failures={
                    failed_local_rank: ProcessFailure(
                        local_rank=failed_local_rank,
                        pid=e.pid,
                        exitcode=failed_proc.exitcode,
                        error_file=error_filepath,
                    )
                },
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )

    def pids(self) -> Dict[int, int]:
        assert self._pc is not None  # assertion for mypy type checking
        return dict(enumerate(self._pc.pids()))

    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        if not self._pc:
            return
        for proc in self._pc.processes:
            if proc.is_alive():
                logger.warning(
                    "Closing process %s via signal %s", proc.pid, death_sig.name
                )
                try:
                    os.kill(proc.pid, death_sig)
                except ProcessLookupError:
                    # If the process has already exited, `ProcessLookupError`
                    # is raised; it is safe to ignore it.
                    pass
        end = time.monotonic() + timeout
        for proc in self._pc.processes:
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            proc.join(time_to_wait)
        for proc in self._pc.processes:
            if proc.is_alive():
                logger.warning(
                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
                    proc.pid,
                    death_sig,
                    _get_kill_signal(),
                )
                try:
                    os.kill(proc.pid, _get_kill_signal())
                except ProcessLookupError:
                    # If the process has already exited, `ProcessLookupError`
                    # is raised; it is safe to ignore it.
                    pass
                proc.join()

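
# Illustrative sketch (not part of the module): constructing a MultiprocessContext directly for a
# two-worker function entrypoint. In practice this is normally done through the package-level
# ``start_processes`` helper; the trainer function and env values below are assumed placeholders.
#
#   def trainer(a: int, b: int) -> int:
#       return a + b
#
#   ctx = MultiprocessContext(
#       name="trainer",
#       entrypoint=trainer,
#       args={0: (1, 2), 1: (3, 4)},                            # per-local-rank positional args
#       envs={0: {"LOCAL_RANK": "0"}, 1: {"LOCAL_RANK": "1"}},
#       start_method="spawn",
#       logs_specs=DefaultLogsSpecs(log_dir="/tmp/run_logs"),
#   )
#   ctx.start()
#   result = ctx.wait()
#   # result.return_values == {0: 3, 1: 7} when both workers succeed
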
class SubprocessContext(PContext):
    """``PContext`` holding worker processes invoked as a binary."""

    def __init__(
        self,
        name: str,
        entrypoint: str,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            logs_specs,
            log_line_prefixes,
        )

        # state vector; tracks which local ranks have not finished yet
        self._running_local_ranks: Set[int] = set(range(self.nprocs))
        self._failures: Dict[int, ProcessFailure] = {}
        self.subprocess_handlers: Dict[int, SubprocessHandler] = {}

    def _start(self):
        if self.subprocess_handlers:
            raise ValueError(
                "The subprocess handlers are already initialized. Most likely the start method got called twice."
            )
        self.subprocess_handlers = {
            local_rank: get_subprocess_handler(
                entrypoint=self.entrypoint,  # type: ignore[arg-type] # entrypoint is always a str
                args=self.args[local_rank],
                env=self.envs[local_rank],
                stdout=self.stdouts[local_rank],
                stderr=self.stderrs[local_rank],
                local_rank_id=local_rank,
            )
            for local_rank in range(self.nprocs)
        }

    def _poll(self) -> Optional[RunProcsResult]:
        done_local_ranks = set()
        for local_rank in self._running_local_ranks:
            handler = self.subprocess_handlers[local_rank]
            exitcode = handler.proc.poll()
            if exitcode is not None:
                done_local_ranks.add(local_rank)
                if exitcode != 0:  # failed or signaled
                    self._failures[local_rank] = ProcessFailure(
                        local_rank=local_rank,
                        pid=handler.proc.pid,
                        exitcode=exitcode,
                        error_file=self.error_files[local_rank],
                    )
                # else: --> succeeded; nothing to do

        self._running_local_ranks.difference_update(done_local_ranks)

        # if ALL procs are finished or ANY have failed
        if not self._running_local_ranks or self._failures:
            self.close()  # terminate all running procs
            result = RunProcsResult(
                failures=self._failures,
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )
            if result.is_failed():
                first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
                logger.error(
                    "failed (exitcode: %s)"
                    " local_rank: %s (pid: %s)"
                    " of binary: %s",
                    first_failure.exitcode,
                    first_failure.local_rank,
                    first_failure.pid,
                    self.entrypoint,
                )
            else:
                # Populate return with dummy values.
                # This provides consistency with MultiprocessContext.
                result.return_values = dict.fromkeys(range(self.nprocs))

            return result
        else:  # there are no failures and procs still running
            return None

    def pids(self) -> Dict[int, int]:
        return {
            local_rank: sh.proc.pid
            for local_rank, sh in self.subprocess_handlers.items()
        }

    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        if not self.subprocess_handlers:
            return
        for handler in self.subprocess_handlers.values():
            if handler.proc.poll() is None:
                logger.warning(
                    "Sending process %s closing signal %s",
                    handler.proc.pid,
                    death_sig.name,
                )
                handler.close(death_sig=death_sig)
        end = time.monotonic() + timeout
        for handler in self.subprocess_handlers.values():
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            try:
                handler.proc.wait(time_to_wait)
            except subprocess.TimeoutExpired:
                # Ignore the timeout expired exception, since
                # the child process will be forcefully terminated via SIGKILL
                pass
        for handler in self.subprocess_handlers.values():
            if handler.proc.poll() is None:
                logger.warning(
                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
                    handler.proc.pid,
                    death_sig,
                    _get_kill_signal(),
                )
                handler.close(death_sig=_get_kill_signal())
                handler.proc.wait()

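
# Illustrative sketch (not part of the module): driving a SubprocessContext with a binary
# entrypoint. The binary path and arguments are assumed placeholders; real launches typically
# go through the package-level ``start_processes`` helper instead of constructing this directly.
#
#   ctx = SubprocessContext(
#       name="echo",
#       entrypoint="/bin/echo",
#       args={0: ("hello from rank 0",), 1: ("hello from rank 1",)},
#       envs={0: {"LOCAL_RANK": "0"}, 1: {"LOCAL_RANK": "1"}},
#       logs_specs=DefaultLogsSpecs(redirects=Std.ALL),
#   )
#   ctx.start()
#   result = ctx.wait()
#   # for binaries, result.return_values holds dummy None values; check result.is_failed()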