1#!/usr/bin/env python3 2# Owner(s): ["oncall: r2p"] 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9import asyncio 10import ctypes 11import multiprocessing 12import os 13import shutil 14import signal 15import sys 16import tempfile 17import time 18from itertools import product 19from typing import Callable, Dict, List, Union 20from unittest import mock 21 22import torch 23import torch.multiprocessing as mp 24from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes 25from torch.distributed.elastic.multiprocessing.api import ( 26 _validate_full_rank, 27 _wrap, 28 DefaultLogsSpecs, 29 MultiprocessContext, 30 RunProcsResult, 31 SignalException, 32 Std, 33 to_map, 34) 35from torch.distributed.elastic.multiprocessing.errors import ErrorHandler 36from torch.testing._internal.common_utils import ( 37 IS_CI, 38 IS_MACOS, 39 IS_WINDOWS, 40 NO_MULTIPROCESSING_SPAWN, 41 run_tests, 42 skip_but_pass_in_sandcastle_if, 43 skip_if_pytest, 44 TEST_WITH_ASAN, 45 TEST_WITH_DEV_DBG_ASAN, 46 TEST_WITH_TSAN, 47 TestCase, 48) 49 50 51class RunProcResultsTest(TestCase): 52 def setUp(self): 53 super().setUp() 54 self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_") 55 56 def tearDown(self): 57 super().tearDown() 58 shutil.rmtree(self.test_dir) 59 60 def test_is_failed(self): 61 pr_success = RunProcsResult(return_values={0: "a", 1: "b"}) 62 self.assertFalse(pr_success.is_failed()) 63 64 fail0 = ProcessFailure( 65 local_rank=0, pid=998, exitcode=1, error_file="ignored.json" 66 ) 67 pr_fail = RunProcsResult(failures={0: fail0}) 68 self.assertTrue(pr_fail.is_failed()) 69 70 def test_get_failures(self): 71 error_file0 = os.path.join(self.test_dir, "error0.json") 72 error_file1 = os.path.join(self.test_dir, "error1.json") 73 eh = ErrorHandler() 74 with mock.patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": error_file0}): 75 eh.record_exception(RuntimeError("error 0")) 76 77 with mock.patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": error_file0}): 78 eh.record_exception(RuntimeError("error 1")) 79 80 fail0 = ProcessFailure( 81 local_rank=0, pid=997, exitcode=1, error_file=error_file0 82 ) 83 fail1 = ProcessFailure( 84 local_rank=1, pid=998, exitcode=3, error_file=error_file1 85 ) 86 fail2 = ProcessFailure( 87 local_rank=2, pid=999, exitcode=15, error_file="no_exist.json" 88 ) 89 90 self.assertLessEqual(fail0.timestamp, fail1.timestamp) 91 self.assertLessEqual(fail1.timestamp, fail2.timestamp) 92 93 94class StdTest(TestCase): 95 def test_from_value(self): 96 self.assertEqual(Std.NONE, Std.from_str("0")) 97 self.assertEqual(Std.OUT, Std.from_str("1")) 98 self.assertEqual(Std.ERR, Std.from_str("2")) 99 self.assertEqual(Std.ALL, Std.from_str("3")) 100 101 def test_from_value_map(self): 102 self.assertEqual({0: Std.OUT}, Std.from_str("0:1")) 103 self.assertEqual({0: Std.OUT, 1: Std.OUT}, Std.from_str("0:1,1:1")) 104 105 def test_from_str_bad_input(self): 106 bad_inputs = ["0:1,", "11", "0:1,1", "1,0:1"] 107 for bad in bad_inputs: 108 with self.subTest(bad=bad): 109 with self.assertRaises(ValueError): 110 Std.from_str(bad) 111 112 113def echo0(msg: str) -> None: 114 """ 115 void function 116 """ 117 print(msg) 118 119 120def echo1(msg: str, exitcode: int = 0) -> str: 121 """ 122 returns ``msg`` or exits with the given exitcode (if nonzero) 123 """ 124 125 rank = int(os.environ["RANK"]) 126 if exitcode != 0: 127 print(f"exit {exitcode} from {rank}", file=sys.stderr) 128 sys.exit(exitcode) 129 else: 130 print(f"{msg} stdout from {rank}") 131 print(f"{msg} stderr from {rank}", file=sys.stderr) 132 return f"{msg}_{rank}" 133 134 135def echo2(msg: str, fail: bool = False) -> str: 136 """ 137 returns ``msg`` or raises a RuntimeError if ``fail`` is set 138 """ 139 if fail: 140 raise RuntimeError(msg) 141 return msg 142 143 144def echo_large(size: int) -> Dict[int, str]: 145 """ 146 returns a large output ({0: test0", 1: "test1", ..., (size-1):f"test{size-1}"}) 147 """ 148 out = {} 149 for idx in range(0, size): 150 out[idx] = f"test{idx}" 151 return out 152 153 154def echo3(msg: str, fail: bool = False) -> str: 155 """ 156 returns ``msg`` or induces a SIGSEGV if ``fail`` is set 157 """ 158 if fail: 159 ctypes.string_at(0) 160 return msg 161 162 163def dummy_compute() -> torch.Tensor: 164 """ 165 returns a predefined size random Tensor 166 """ 167 return torch.rand(100, 100) 168 169 170def redirects_oss_test() -> List[Std]: 171 return [ 172 Std.NONE, 173 ] 174 175 176def redirects_all() -> List[Std]: 177 return [ 178 Std.NONE, 179 Std.OUT, 180 Std.ERR, 181 Std.ALL, 182 ] 183 184 185def bin(name: str): 186 dir = os.path.dirname(__file__) 187 return os.path.join(dir, "bin", name) 188 189 190def wait_fn(wait_time: int = 300) -> None: 191 time.sleep(wait_time) 192 print("Finished waiting") 193 194 195def start_processes_zombie_test( 196 idx: int, 197 entrypoint: Union[str, Callable], 198 mp_queue: mp.Queue, 199 log_dir: str, 200 nproc: int = 2, 201) -> None: 202 """ 203 Starts processes 204 """ 205 206 args = {} 207 envs = {} 208 for idx in range(nproc): 209 args[idx] = () 210 envs[idx] = {} 211 212 pc = start_processes( 213 name="zombie_test", 214 entrypoint=entrypoint, 215 args=args, 216 envs=envs, 217 logs_specs=DefaultLogsSpecs(log_dir=log_dir), 218 ) 219 my_pid = os.getpid() 220 mp_queue.put(my_pid) 221 for child_pid in pc.pids().values(): 222 mp_queue.put(child_pid) 223 224 try: 225 pc.wait(period=1, timeout=300) 226 except SignalException as e: 227 pc.close(e.sigval) 228 229 230class _StartProcessesTest(TestCase): 231 def setUp(self): 232 super().setUp() 233 self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_") 234 self._start_methods = ["spawn"] 235 236 def tearDown(self): 237 super().tearDown() 238 shutil.rmtree(self.test_dir) 239 240 def log_dir(self): 241 return tempfile.mkdtemp(dir=self.test_dir) 242 243 def assert_in_file(self, expected: List[str], filename: str) -> None: 244 expected = [f"{line.rstrip()}\n" for line in expected] 245 with open(filename) as fp: 246 actual = fp.readlines() 247 for line in expected: 248 self.assertIn(line, actual) 249 250 def assert_pids_noexist(self, pids: Dict[int, int]): 251 for local_rank, pid in pids.items(): 252 with self.assertRaises( 253 OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist" 254 ): 255 os.kill(pid, 0) 256 257 def _test_zombie_workflow( 258 self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals 259 ) -> None: 260 mp_queue = mp.get_context("spawn").Queue() 261 child_nproc = 2 262 ctx = mp.spawn( 263 start_processes_zombie_test, 264 nprocs=1, 265 args=(entrypoint, mp_queue, self.log_dir(), child_nproc), 266 join=False, 267 ) 268 total_processes = child_nproc + 1 269 pids = [] 270 for _ in range(total_processes): 271 pids.append(mp_queue.get(timeout=120)) 272 parent_pid = pids[0] 273 child_pids = pids[1:] 274 275 os.kill(parent_pid, signal.SIGTERM) 276 # Wait to give time for signal handlers to finish work 277 time.sleep(5) 278 for child_pid in child_pids: 279 # Killing parent should kill all children, we expect that each call to 280 # os.kill would raise OSError 281 with self.assertRaises(OSError): 282 os.kill(child_pid, 0) 283 284 285# tests incompatible with tsan or asan 286if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): 287 288 class StartProcessesAsFuncTest(_StartProcessesTest): 289 def test_to_map(self): 290 local_world_size = 2 291 self.assertEqual( 292 {0: Std.OUT, 1: Std.OUT}, to_map(Std.OUT, local_world_size) 293 ) 294 self.assertEqual( 295 {0: Std.NONE, 1: Std.OUT}, to_map({1: Std.OUT}, local_world_size) 296 ) 297 self.assertEqual( 298 {0: Std.ERR, 1: Std.OUT}, 299 to_map({0: Std.ERR, 1: Std.OUT}, local_world_size), 300 ) 301 302 def test_invalid_log_dir(self): 303 with tempfile.NamedTemporaryFile(dir=self.test_dir) as not_a_dir: 304 cases = { 305 not_a_dir.name: NotADirectoryError, 306 } 307 308 for log_dir, expected_error in cases.items(): 309 with self.subTest(log_dir=log_dir, expected_error=expected_error): 310 with self.assertRaises(expected_error): 311 pc = None 312 try: 313 pc = start_processes( 314 name="echo", 315 entrypoint=echo1, 316 args={0: ("hello",)}, 317 envs={0: {"RANK": "0"}}, 318 logs_specs=DefaultLogsSpecs(log_dir=log_dir), 319 ) 320 finally: 321 if pc: 322 pc.close() 323 324 def test_args_env_len_mismatch(self): 325 cases = [ 326 # 1 x args; 2 x envs 327 { 328 "args": {0: ("hello",)}, 329 "envs": {0: {"RANK": "0"}, 1: {"RANK": "1"}}, 330 }, 331 # 2 x args; 1 x envs 332 { 333 "args": {0: ("hello",), 1: ("world",)}, 334 "envs": {0: {"RANK": "0"}}, 335 }, 336 ] 337 338 for kwds in cases: 339 args = kwds["args"] 340 envs = kwds["envs"] 341 with self.subTest(args=args, envs=envs): 342 with self.assertRaises(RuntimeError): 343 start_processes( 344 name="echo", 345 entrypoint=echo1, 346 args=args, 347 envs=envs, 348 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), 349 ) 350 351 def test_pcontext_wait(self): 352 pc = start_processes( 353 name="sleep", 354 entrypoint=time.sleep, 355 args={0: (1,)}, 356 envs={0: {}}, 357 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), 358 start_method="spawn", 359 ) 360 361 self.assertIsNone(pc.wait(timeout=0.1, period=0.01)) 362 self.assertIsNotNone(pc.wait(period=0.1)) 363 self.assertTrue(pc._stderr_tail.stopped()) 364 self.assertTrue(pc._stdout_tail.stopped()) 365 366 def test_pcontext_wait_on_a_child_thread(self): 367 asyncio.run(asyncio.to_thread(self.test_pcontext_wait)) 368 369 def test_multiprocess_context_close(self): 370 pc = start_processes( 371 name="sleep", 372 entrypoint=time.sleep, 373 args={0: (1,)}, 374 envs={0: {}}, 375 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), 376 start_method="spawn", 377 ) 378 379 pids = pc.pids() 380 pc.close() 381 self.assert_pids_noexist(pids) 382 self.assertTrue(pc._stderr_tail.stopped()) 383 self.assertTrue(pc._stdout_tail.stopped()) 384 385 def test_function_with_tensor(self): 386 for start_method in self._start_methods: 387 pc = start_processes( 388 name="dummy_compute", 389 entrypoint=dummy_compute, 390 args={0: ()}, 391 envs={0: {}}, 392 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), 393 start_method=start_method, 394 ) 395 396 results = pc.wait() 397 self.assert_pids_noexist(pc.pids()) 398 for return_value in results.return_values.values(): 399 self.assertIsInstance(return_value, torch.Tensor) 400 self.assertEqual((100, 100), return_value.shape) 401 402 def test_void_function(self): 403 for start_method in self._start_methods: 404 with self.subTest(start_method=start_method): 405 pc = start_processes( 406 name="echo", 407 entrypoint=echo0, 408 args={0: ("hello",), 1: ("world",)}, 409 envs={0: {}, 1: {}}, 410 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), 411 start_method=start_method, 412 ) 413 414 results = pc.wait(period=0.1) 415 self.assertEqual({0: None, 1: None}, results.return_values) 416 417 @skip_but_pass_in_sandcastle_if( 418 TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan" 419 ) 420 def test_function_large_ret_val(self): 421 # python multiprocessing.queue module uses pipes and actually PipedQueues 422 # This means that if a single object is greater than a pipe size 423 # the writer process will block until reader process will start 424 # reading the pipe. 425 # This test makes a worker fn to return huge output, around ~10 MB 426 427 size = 200000 428 for start_method in self._start_methods: 429 with self.subTest(start_method=start_method): 430 pc = start_processes( 431 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), 432 name="echo", 433 entrypoint=echo_large, 434 args={0: (size,), 1: (size,), 2: (size,), 3: (size,)}, 435 envs={0: {}, 1: {}, 2: {}, 3: {}}, 436 start_method=start_method, 437 ) 438 439 results = pc.wait(period=0.1) 440 for i in range(pc.nprocs): 441 self.assertEqual(size, len(results.return_values[i])) 442 443 def test_function_raise(self): 444 """ 445 run 2x copies of echo2, raise an exception on the first 446 """ 447 RAISE = True 448 449 for start_method in self._start_methods: 450 with self.subTest(start_method=start_method): 451 log_dir = self.log_dir() 452 pc = start_processes( 453 name="echo", 454 entrypoint=echo2, 455 args={0: ("hello", RAISE), 1: ("world",)}, 456 envs={ 457 0: {"TORCHELASTIC_RUN_ID": "run_id"}, 458 1: {"TORCHELASTIC_RUN_ID": "run_id"}, 459 }, 460 logs_specs=DefaultLogsSpecs(log_dir=log_dir), 461 start_method=start_method, 462 ) 463 464 results = pc.wait(period=0.1) 465 466 self.assert_pids_noexist(pc.pids()) 467 self.assertEqual(1, len(results.failures)) 468 self.assertFalse(results.return_values) 469 470 failure = results.failures[0] 471 error_file = failure.error_file 472 error_file_data = failure.error_file_data 473 474 self.assertEqual(1, failure.exitcode) 475 self.assertEqual("<N/A>", failure.signal_name()) 476 self.assertEqual(pc.pids()[0], failure.pid) 477 self.assertTrue( 478 error_file.startswith(os.path.join(log_dir, "run_id_")) 479 ) 480 self.assertTrue(error_file.endswith("attempt_0/0/error.json")) 481 self.assertEqual( 482 int(error_file_data["message"]["extraInfo"]["timestamp"]), 483 int(failure.timestamp), 484 ) 485 self.assertTrue(pc._stderr_tail.stopped()) 486 self.assertTrue(pc._stdout_tail.stopped()) 487 488 def test_wait_for_all_child_procs_to_exit(self): 489 """ 490 Tests that MultiprocessingContext actually waits for 491 the child process to exit (not just that the entrypoint fn has 492 finished running). 493 """ 494 495 mpc = MultiprocessContext( 496 name="echo", 497 entrypoint=echo0, 498 args={}, 499 envs={}, 500 start_method="spawn", 501 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), 502 ) 503 504 with mock.patch.object( 505 mpc, "_is_done", return_value=True 506 ), mock.patch.object(mpc, "_pc"), mock.patch.object( 507 mpc._pc, "join", side_effect=[True, False, False, True] 508 ) as mock_join: 509 mpc._poll() 510 self.assertEqual(4, mock_join.call_count) 511 512 @skip_but_pass_in_sandcastle_if( 513 NO_MULTIPROCESSING_SPAWN, 514 "Disabled for environments that \ 515 don't support multiprocessing with spawn start method", 516 ) 517 def test_multiprocessing_context_poll_raises_exception(self): 518 mp_context = MultiprocessContext( 519 name="test_mp", 520 entrypoint=echo0, 521 args={0: (0, 1)}, 522 envs={0: {}}, 523 logs_specs=DefaultLogsSpecs( 524 log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL 525 ), 526 start_method="spawn", 527 ) 528 mp_context._pc = mock.Mock() 529 # Using mock since we cannot just set exitcode on process 530 mock_process = mock.Mock() 531 mock_process.exitcode = -1 532 mp_context._pc.processes = [mock_process] 533 e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123) 534 mp_context._pc.join.side_effect = e 535 with mock.patch.object(mp_context, "close"): 536 run_result = mp_context._poll() 537 self.assertEqual(1, len(run_result.failures)) 538 failure = run_result.failures[0] 539 self.assertEqual( 540 "Signal 1 (SIGHUP) received by PID 123", failure.message 541 ) 542 543 class StartProcessesAsBinaryTest(_StartProcessesTest): 544 ######################################## 545 # start_processes as binary tests 546 ######################################## 547 548 def test_subprocess_context_close(self): 549 pc = start_processes( 550 name="sleep", 551 entrypoint=bin("zombie_test.py"), 552 args={0: (1,)}, 553 envs={0: {}}, 554 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), 555 ) 556 557 pids = pc.pids() 558 pc.close() 559 self.assert_pids_noexist(pids) 560 561 def test_binary_exit(self): 562 FAIL = 138 563 pc = start_processes( 564 name="echo", 565 entrypoint=bin("echo1.py"), 566 args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")}, 567 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, 568 logs_specs=DefaultLogsSpecs( 569 log_dir=self.log_dir(), 570 redirects={0: Std.ALL}, 571 ), 572 ) 573 574 results = pc.wait(period=0.1) 575 576 self.assertTrue(results.is_failed()) 577 self.assertEqual(1, len(results.failures)) 578 579 failure = results.failures[0] 580 self.assertEqual(138, failure.exitcode) 581 self.assertEqual("<N/A>", failure.signal_name()) 582 self.assertEqual("<NONE>", failure.error_file_data["message"]) 583 self.assert_in_file([f"exit {FAIL} from 0"], results.stderrs[0]) 584 self.assert_in_file([], results.stdouts[0]) 585 self.assertFalse(results.stderrs[1]) 586 self.assertFalse(results.stdouts[1]) 587 self.assertTrue(pc._stderr_tail.stopped()) 588 self.assertTrue(pc._stdout_tail.stopped()) 589 590 def test_binary_raises(self): 591 pc = start_processes( 592 name="echo", 593 entrypoint=bin("echo2.py"), 594 args={0: ("--raises", "true", "foo"), 1: ("bar",)}, 595 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, 596 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), 597 ) 598 599 results = pc.wait(period=0.1) 600 601 self.assert_pids_noexist(pc.pids()) 602 self.assertTrue(results.is_failed()) 603 self.assertEqual(1, len(results.failures)) 604 605 failure = results.failures[0] 606 self.assertEqual(1, failure.exitcode) 607 self.assertEqual("<NONE>", failure.error_file_data["message"]) 608 self.assertEqual("<N/A>", failure.signal_name()) 609 610 def test_binary_incorrect_entrypoint(self): 611 with self.assertRaises(FileNotFoundError): 612 start_processes( 613 name="echo", 614 entrypoint="does_not_exist.py", 615 args={0: ("foo"), 1: ("bar",)}, 616 envs={0: {}, 1: {}}, 617 logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), 618 ) 619 620 def test_validate_full_rank(self): 621 with self.assertRaises(RuntimeError): 622 _validate_full_rank({}, 10, "") 623 624 625# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows 626if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): 627 628 class StartProcessesListAsFuncTest(_StartProcessesTest): 629 def test_function(self): 630 for start_method, redirs in product( 631 self._start_methods, redirects_oss_test() 632 ): 633 with self.subTest(start_method=start_method, redirs=redirs): 634 pc = start_processes( 635 name="echo", 636 entrypoint=echo1, 637 args={0: ("hello",), 1: ("hello",)}, 638 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, 639 logs_specs=DefaultLogsSpecs( 640 log_dir=self.log_dir(), 641 redirects=redirs, 642 ), 643 start_method=start_method, 644 ) 645 646 results = pc.wait(period=0.1) 647 nprocs = pc.nprocs 648 649 self.assert_pids_noexist(pc.pids()) 650 self.assertEqual( 651 {i: f"hello_{i}" for i in range(nprocs)}, results.return_values 652 ) 653 654 for i in range(nprocs): 655 if redirs & Std.OUT != Std.OUT: 656 self.assertFalse(results.stdouts[i]) 657 if redirs & Std.ERR != Std.ERR: 658 self.assertFalse(results.stderrs[i]) 659 if redirs & Std.OUT == Std.OUT: 660 self.assert_in_file( 661 [f"hello stdout from {i}"], results.stdouts[i] 662 ) 663 if redirs & Std.ERR == Std.ERR: 664 self.assert_in_file( 665 [f"hello stderr from {i}"], results.stderrs[i] 666 ) 667 668 class StartProcessesListAsBinaryTest(_StartProcessesTest): 669 ######################################## 670 # start_processes as binary tests 671 ######################################## 672 def test_binary(self): 673 for redirs in redirects_oss_test(): 674 with self.subTest(redirs=redirs): 675 pc = start_processes( 676 name="echo", 677 entrypoint=bin("echo1.py"), 678 args={0: ("hello",), 1: ("hello",)}, 679 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, 680 logs_specs=DefaultLogsSpecs( 681 log_dir=self.log_dir(), 682 redirects=redirs, 683 ), 684 log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"}, 685 ) 686 687 results = pc.wait(period=0.1) 688 689 self.assert_pids_noexist(pc.pids()) 690 # currently binaries return {rank: None} 691 self.assertEqual(2, len(results.return_values)) 692 self.assertFalse(results.is_failed()) 693 694 nprocs = pc.nprocs 695 for i in range(nprocs): 696 if redirs & Std.OUT != Std.OUT: 697 self.assertFalse(results.stdouts[i]) 698 if redirs & Std.ERR != Std.ERR: 699 self.assertFalse(results.stderrs[i]) 700 if redirs & Std.OUT == Std.OUT: 701 self.assert_in_file( 702 [f"hello stdout from {i}"], results.stdouts[i] 703 ) 704 if redirs & Std.ERR == Std.ERR: 705 self.assert_in_file( 706 [f"hello stderr from {i}"], results.stderrs[i] 707 ) 708 709 def test_binary_redirect_and_tee(self): 710 pc = start_processes( 711 name="trainer", 712 entrypoint=bin("echo1.py"), 713 args={0: ("hello",), 1: ("world",)}, 714 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, 715 logs_specs=DefaultLogsSpecs( 716 log_dir=self.log_dir(), 717 redirects={0: Std.ERR, 1: Std.NONE}, 718 tee={0: Std.OUT, 1: Std.ERR}, 719 ), 720 log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"}, 721 start_method="spawn", 722 ) 723 724 result = pc.wait() 725 726 self.assertFalse(result.is_failed()) 727 self.assert_in_file(["hello stdout from 0"], pc.stdouts[0]) 728 self.assert_in_file(["hello stderr from 0"], pc.stderrs[0]) 729 self.assert_in_file(["world stderr from 1"], pc.stderrs[1]) 730 self.assertFalse(pc.stdouts[1]) 731 self.assertTrue(pc._stderr_tail.stopped()) 732 self.assertTrue(pc._stdout_tail.stopped()) 733 734 735# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows 736if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI): 737 738 class StartProcessesNotCIAsFuncTest(_StartProcessesTest): 739 @skip_if_pytest 740 def test_wrap_bad(self): 741 none = "" 742 stdout_log = os.path.join(self.test_dir, "stdout.log") 743 stderr_log = os.path.join(self.test_dir, "stderr.log") 744 redirs = [ 745 (none, none), 746 (none, stderr_log), 747 (stdout_log, none), 748 (stdout_log, stderr_log), 749 ] 750 751 for stdout_redir, stderr_redir in redirs: 752 queue = multiprocessing.SimpleQueue() 753 worker_finished_event_mock = mock.Mock() 754 _wrap( 755 local_rank=0, 756 fn=echo1, 757 args={0: ("hello",)}, 758 envs={0: {"RANK": "0"}}, 759 stdout_redirects={0: stdout_redir}, 760 stderr_redirects={0: stderr_redir}, 761 ret_vals={0: queue}, 762 queue_finished_reading_event=worker_finished_event_mock, 763 ) 764 self.assertEqual("hello_0", queue.get()) 765 if stdout_redir: 766 self.assert_in_file(["hello stdout from 0"], stdout_log) 767 if stderr_redir: 768 self.assert_in_file(["hello stderr from 0"], stderr_log) 769 worker_finished_event_mock.wait.assert_called_once() 770 771 def test_function_redirect_and_tee(self): 772 for start_method in self._start_methods: 773 with self.subTest(start_method=start_method): 774 pc = start_processes( 775 name="trainer", 776 entrypoint=echo1, 777 args={0: ("hello",), 1: ("world",)}, 778 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, 779 logs_specs=DefaultLogsSpecs( 780 log_dir=self.log_dir(), 781 redirects={0: Std.ERR, 1: Std.NONE}, 782 tee={0: Std.OUT, 1: Std.ERR}, 783 ), 784 start_method="spawn", 785 ) 786 787 result = pc.wait() 788 789 self.assertFalse(result.is_failed()) 790 self.assert_in_file(["hello stdout from 0"], pc.stdouts[0]) 791 self.assert_in_file(["hello stderr from 0"], pc.stderrs[0]) 792 self.assert_in_file(["world stderr from 1"], pc.stderrs[1]) 793 self.assertFalse(pc.stdouts[1]) 794 self.assertTrue(pc._stderr_tail.stopped()) 795 self.assertTrue(pc._stdout_tail.stopped()) 796 797 def test_function(self): 798 for start_method, redirs in product(self._start_methods, redirects_all()): 799 with self.subTest(start_method=start_method, redirs=redirs): 800 pc = start_processes( 801 name="echo", 802 entrypoint=echo1, 803 args={0: ("hello",), 1: ("hello",)}, 804 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, 805 start_method=start_method, 806 logs_specs=DefaultLogsSpecs( 807 log_dir=self.log_dir(), 808 redirects=redirs, 809 ), 810 ) 811 812 results = pc.wait(period=0.1) 813 nprocs = pc.nprocs 814 815 self.assert_pids_noexist(pc.pids()) 816 self.assertEqual( 817 {i: f"hello_{i}" for i in range(nprocs)}, results.return_values 818 ) 819 820 for i in range(nprocs): 821 if redirs & Std.OUT != Std.OUT: 822 self.assertFalse(results.stdouts[i]) 823 if redirs & Std.ERR != Std.ERR: 824 self.assertFalse(results.stderrs[i]) 825 if redirs & Std.OUT == Std.OUT: 826 self.assert_in_file( 827 [f"hello stdout from {i}"], results.stdouts[i] 828 ) 829 if redirs & Std.ERR == Std.ERR: 830 self.assert_in_file( 831 [f"hello stderr from {i}"], results.stderrs[i] 832 ) 833 834 def test_function_exit(self): 835 """ 836 run 2x copies of echo1 fail (exit) the first 837 functions that exit from python do not generate an error file 838 (even if they are decorated with @record) 839 """ 840 841 FAIL = 138 842 for start_method in self._start_methods: 843 with self.subTest(start_method=start_method): 844 pc = start_processes( 845 name="echo", 846 entrypoint=echo1, 847 args={0: ("hello", FAIL), 1: ("hello",)}, 848 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, 849 logs_specs=DefaultLogsSpecs( 850 log_dir=self.log_dir(), 851 redirects={0: Std.ERR}, 852 ), 853 start_method=start_method, 854 ) 855 856 results = pc.wait(period=0.1) 857 858 self.assert_pids_noexist(pc.pids()) 859 self.assertTrue(results.is_failed()) 860 self.assertEqual(1, len(results.failures)) 861 self.assertFalse(results.return_values) 862 863 failure = results.failures[0] 864 error_file = failure.error_file 865 866 self.assertEqual(FAIL, failure.exitcode) 867 self.assertEqual("<N/A>", failure.signal_name()) 868 self.assertEqual(pc.pids()[0], failure.pid) 869 self.assertEqual("<N/A>", error_file) 870 self.assertEqual( 871 "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html", 872 failure.message, 873 ) 874 self.assertLessEqual(failure.timestamp, int(time.time())) 875 876 self.assert_in_file([f"exit {FAIL} from 0"], results.stderrs[0]) 877 self.assertFalse(results.stdouts[0]) 878 self.assertFalse(results.stderrs[1]) 879 self.assertFalse(results.stdouts[1]) 880 self.assertTrue(pc._stderr_tail.stopped()) 881 self.assertTrue(pc._stdout_tail.stopped()) 882 883 def test_no_zombie_process_function(self): 884 signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT] 885 for s in signals: 886 self._test_zombie_workflow(wait_fn, s) 887 888 class StartProcessesNotCIAsBinaryTest(_StartProcessesTest): 889 def test_binary_signal(self): 890 pc = start_processes( 891 name="echo", 892 entrypoint=bin("echo3.py"), 893 args={0: ("--segfault", "true", "foo"), 1: ("bar",)}, 894 envs={0: {"RANK": "0"}, 1: {"RANK": "1"}}, 895 logs_specs=DefaultLogsSpecs( 896 log_dir=self.log_dir(), 897 ), 898 ) 899 900 results = pc.wait(period=0.1) 901 902 self.assert_pids_noexist(pc.pids()) 903 self.assertTrue(results.is_failed()) 904 self.assertEqual(1, len(results.failures)) 905 906 failure = results.failures[0] 907 self.assertNotEqual(signal.SIGSEGV, failure.exitcode) 908 if TEST_WITH_ASAN or TEST_WITH_TSAN: 909 # ASAN/TSAN exit code is 1. 910 self.assertEqual("<N/A>", failure.signal_name()) 911 else: 912 self.assertEqual("SIGSEGV", failure.signal_name()) 913 self.assertEqual("<NONE>", failure.error_file_data["message"]) 914 915 def test_no_zombie_process_binary(self): 916 signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT] 917 for s in signals: 918 self._test_zombie_workflow(bin("zombie_test.py"), s) 919 920 class ForkServerTest( 921 StartProcessesAsFuncTest, 922 StartProcessesListAsFuncTest, 923 StartProcessesNotCIAsFuncTest, 924 ): 925 def setUp(self): 926 super().setUp() 927 self._start_methods = ["forkserver"] 928 self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START) 929 os.environ[mp.ENV_VAR_PARALLEL_START] = "1" 930 931 def tearDown(self): 932 super().tearDown() 933 if self.orig_paralell_env_val is None: 934 del os.environ[mp.ENV_VAR_PARALLEL_START] 935 else: 936 os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val 937 938 939if __name__ == "__main__": 940 run_tests() 941