# Owner(s): ["module: multiprocessing"] import os import pickle import random import signal import sys import time import unittest import torch.multiprocessing as mp from torch.testing._internal.common_utils import ( IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, run_tests, TestCase, ) def _test_success_func(i): pass def _test_success_single_arg_func(i, arg): if arg: arg.put(i) def _test_exception_single_func(i, arg): if i == arg: raise ValueError("legitimate exception from process %d" % i) time.sleep(1.0) def _test_exception_all_func(i): time.sleep(random.random() / 10) raise ValueError("legitimate exception from process %d" % i) def _test_terminate_signal_func(i): if i == 0: os.kill(os.getpid(), signal.SIGABRT) time.sleep(1.0) def _test_terminate_exit_func(i, arg): if i == 0: sys.exit(arg) time.sleep(1.0) def _test_success_first_then_exception_func(i, arg): if i == 0: return time.sleep(0.1) raise ValueError("legitimate exception") def _test_nested_child_body(i, ready_queue, nested_child_sleep): ready_queue.put(None) time.sleep(nested_child_sleep) def _test_infinite_task(i): while True: time.sleep(1) def _test_process_exit(idx): sys.exit(12) def _test_nested(i, pids_queue, nested_child_sleep, start_method): context = mp.get_context(start_method) nested_child_ready_queue = context.Queue() nprocs = 2 mp_context = mp.start_processes( fn=_test_nested_child_body, args=(nested_child_ready_queue, nested_child_sleep), nprocs=nprocs, join=False, daemon=False, start_method=start_method, ) pids_queue.put(mp_context.pids()) # Wait for both children to have started, to ensure that they # have called prctl(2) to register a parent death signal. for _ in range(nprocs): nested_child_ready_queue.get() # Kill self. This should take down the child processes as well. os.kill(os.getpid(), signal.SIGTERM) class _TestMultiProcessing: start_method = None def test_success(self): mp.start_processes(_test_success_func, nprocs=2, start_method=self.start_method) def test_success_non_blocking(self): mp_context = mp.start_processes(_test_success_func, nprocs=2, join=False, start_method=self.start_method) # After all processes (nproc=2) have joined it must return True mp_context.join(timeout=None) mp_context.join(timeout=None) self.assertTrue(mp_context.join(timeout=None)) def test_first_argument_index(self): context = mp.get_context(self.start_method) queue = context.SimpleQueue() mp.start_processes(_test_success_single_arg_func, args=(queue,), nprocs=2, start_method=self.start_method) self.assertEqual([0, 1], sorted([queue.get(), queue.get()])) def test_exception_single(self): nprocs = 2 for i in range(nprocs): with self.assertRaisesRegex( Exception, "\nValueError: legitimate exception from process %d$" % i, ): mp.start_processes(_test_exception_single_func, args=(i,), nprocs=nprocs, start_method=self.start_method) def test_exception_all(self): with self.assertRaisesRegex( Exception, "\nValueError: legitimate exception from process (0|1)$", ): mp.start_processes(_test_exception_all_func, nprocs=2, start_method=self.start_method) def test_terminate_signal(self): # SIGABRT is aliased with SIGIOT message = "process 0 terminated with signal (SIGABRT|SIGIOT)" # Termination through with signal is expressed as a negative exit code # in multiprocessing, so we know it was a signal that caused the exit. # This doesn't appear to exist on Windows, where the exit code is always # positive, and therefore results in a different exception message. # Exit code 22 means "ERROR_BAD_COMMAND". if IS_WINDOWS: message = "process 0 terminated with exit code 22" with self.assertRaisesRegex(Exception, message): mp.start_processes(_test_terminate_signal_func, nprocs=2, start_method=self.start_method) def test_terminate_exit(self): exitcode = 123 with self.assertRaisesRegex( Exception, "process 0 terminated with exit code %d" % exitcode, ): mp.start_processes(_test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method) def test_success_first_then_exception(self): exitcode = 123 with self.assertRaisesRegex( Exception, "ValueError: legitimate exception", ): mp.start_processes(_test_success_first_then_exception_func, args=(exitcode,), nprocs=2, start_method=self.start_method) @unittest.skipIf( sys.platform != "linux", "Only runs on Linux; requires prctl(2)", ) def _test_nested(self): context = mp.get_context(self.start_method) pids_queue = context.Queue() nested_child_sleep = 20.0 mp_context = mp.start_processes( fn=_test_nested, args=(pids_queue, nested_child_sleep, self.start_method), nprocs=1, join=False, daemon=False, start_method=self.start_method, ) # Wait for nested children to terminate in time pids = pids_queue.get() start = time.time() while len(pids) > 0: for pid in pids: try: os.kill(pid, 0) except ProcessLookupError: pids.remove(pid) break # This assert fails if any nested child process is still # alive after (nested_child_sleep / 2) seconds. By # extension, this test times out with an assertion error # after (nested_child_sleep / 2) seconds. self.assertLess(time.time() - start, nested_child_sleep / 2) time.sleep(0.1) @unittest.skipIf( NO_MULTIPROCESSING_SPAWN, "Disabled for environments that don't support the spawn start method") class SpawnTest(TestCase, _TestMultiProcessing): start_method = 'spawn' def test_exception_raises(self): with self.assertRaises(mp.ProcessRaisedException): mp.spawn(_test_success_first_then_exception_func, args=(), nprocs=1) def test_signal_raises(self): context = mp.spawn(_test_infinite_task, args=(), nprocs=1, join=False) for pid in context.pids(): os.kill(pid, signal.SIGTERM) with self.assertRaises(mp.ProcessExitedException): context.join() def _test_process_exited(self): with self.assertRaises(mp.ProcessExitedException) as e: mp.spawn(_test_process_exit, args=(), nprocs=1) self.assertEqual(12, e.exit_code) @unittest.skipIf( IS_WINDOWS, "Fork is only available on Unix", ) class ForkTest(TestCase, _TestMultiProcessing): start_method = 'fork' @unittest.skipIf( IS_WINDOWS, "Fork is only available on Unix", ) class ParallelForkServerShouldWorkTest(TestCase, _TestMultiProcessing): orig_paralell_env_val = None def setUp(self): super().setUp() self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START) os.environ[mp.ENV_VAR_PARALLEL_START] = "1" def tearDown(self): super().tearDown() if self.orig_paralell_env_val is None: del os.environ[mp.ENV_VAR_PARALLEL_START] else: os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val @unittest.skipIf( IS_WINDOWS, "Fork is only available on Unix", ) class ParallelForkServerPerfTest(TestCase): def test_forkserver_perf(self): start_method = 'forkserver' expensive = Expensive() nprocs = 4 orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START) # test the non parallel case os.environ[mp.ENV_VAR_PARALLEL_START] = "0" start = time.perf_counter() mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method) elapsed = time.perf_counter() - start # the elapsed time should be at least {nprocs}x the sleep time self.assertGreaterEqual(elapsed, Expensive.SLEEP_SECS * nprocs) # test the parallel case os.environ[mp.ENV_VAR_PARALLEL_START] = "1" start = time.perf_counter() mp.start_processes(expensive.my_call, nprocs=nprocs, start_method=start_method) elapsed = time.perf_counter() - start # the elapsed time should be less than {nprocs}x the sleep time self.assertLess(elapsed, Expensive.SLEEP_SECS * nprocs) if orig_paralell_env_val is None: del os.environ[mp.ENV_VAR_PARALLEL_START] else: os.environ[mp.ENV_VAR_PARALLEL_START] = orig_paralell_env_val class Expensive: SLEEP_SECS = 5 # Simulate startup overhead such as large imports time.sleep(SLEEP_SECS) def __init__(self): self.config: str = "*" * 1000000 def my_call(self, *args): pass class ErrorTest(TestCase): def test_errors_pickleable(self): for error in ( mp.ProcessRaisedException("Oh no!", 1, 1), mp.ProcessExitedException("Oh no!", 1, 1, 1), ): pickle.loads(pickle.dumps(error)) if __name__ == '__main__': run_tests()