xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/multiprocessing/api_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import asyncio
10import ctypes
11import multiprocessing
12import os
13import shutil
14import signal
15import sys
16import tempfile
17import time
18from itertools import product
19from typing import Callable, Dict, List, Union
20from unittest import mock
21
22import torch
23import torch.multiprocessing as mp
24from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
25from torch.distributed.elastic.multiprocessing.api import (
26    _validate_full_rank,
27    _wrap,
28    DefaultLogsSpecs,
29    MultiprocessContext,
30    RunProcsResult,
31    SignalException,
32    Std,
33    to_map,
34)
35from torch.distributed.elastic.multiprocessing.errors import ErrorHandler
36from torch.testing._internal.common_utils import (
37    IS_CI,
38    IS_MACOS,
39    IS_WINDOWS,
40    NO_MULTIPROCESSING_SPAWN,
41    run_tests,
42    skip_but_pass_in_sandcastle_if,
43    skip_if_pytest,
44    TEST_WITH_ASAN,
45    TEST_WITH_DEV_DBG_ASAN,
46    TEST_WITH_TSAN,
47    TestCase,
48)
49
50
class RunProcResultsTest(TestCase):
    """Tests for ``RunProcsResult`` and the ``ProcessFailure`` records it holds."""

    def setUp(self):
        super().setUp()
        # scratch directory for the error files written by ``ErrorHandler``
        self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.test_dir)

    def test_is_failed(self):
        """A result with only return values is not failed; one with failures is."""
        pr_success = RunProcsResult(return_values={0: "a", 1: "b"})
        self.assertFalse(pr_success.is_failed())

        fail0 = ProcessFailure(
            local_rank=0, pid=998, exitcode=1, error_file="ignored.json"
        )
        pr_fail = RunProcsResult(failures={0: fail0})
        self.assertTrue(pr_fail.is_failed())

    def test_get_failures(self):
        """
        Failures created in order (two backed by recorded error files, one
        pointing at a missing file) have non-decreasing timestamps.
        """
        error_file0 = os.path.join(self.test_dir, "error0.json")
        error_file1 = os.path.join(self.test_dir, "error1.json")
        eh = ErrorHandler()
        # each exception is recorded to its own file so that both fail0 and
        # fail1 below are backed by a real error file
        with mock.patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": error_file0}):
            eh.record_exception(RuntimeError("error 0"))

        with mock.patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": error_file1}):
            eh.record_exception(RuntimeError("error 1"))

        self.assertTrue(os.path.isfile(error_file0))
        self.assertTrue(os.path.isfile(error_file1))

        fail0 = ProcessFailure(
            local_rank=0, pid=997, exitcode=1, error_file=error_file0
        )
        fail1 = ProcessFailure(
            local_rank=1, pid=998, exitcode=3, error_file=error_file1
        )
        fail2 = ProcessFailure(
            local_rank=2, pid=999, exitcode=15, error_file="no_exist.json"
        )

        self.assertLessEqual(fail0.timestamp, fail1.timestamp)
        self.assertLessEqual(fail1.timestamp, fail2.timestamp)
92
93
94class StdTest(TestCase):
95    def test_from_value(self):
96        self.assertEqual(Std.NONE, Std.from_str("0"))
97        self.assertEqual(Std.OUT, Std.from_str("1"))
98        self.assertEqual(Std.ERR, Std.from_str("2"))
99        self.assertEqual(Std.ALL, Std.from_str("3"))
100
101    def test_from_value_map(self):
102        self.assertEqual({0: Std.OUT}, Std.from_str("0:1"))
103        self.assertEqual({0: Std.OUT, 1: Std.OUT}, Std.from_str("0:1,1:1"))
104
105    def test_from_str_bad_input(self):
106        bad_inputs = ["0:1,", "11", "0:1,1", "1,0:1"]
107        for bad in bad_inputs:
108            with self.subTest(bad=bad):
109                with self.assertRaises(ValueError):
110                    Std.from_str(bad)
111
112
113def echo0(msg: str) -> None:
114    """
115    void function
116    """
117    print(msg)
118
119
120def echo1(msg: str, exitcode: int = 0) -> str:
121    """
122    returns ``msg`` or exits with the given exitcode (if nonzero)
123    """
124
125    rank = int(os.environ["RANK"])
126    if exitcode != 0:
127        print(f"exit {exitcode} from {rank}", file=sys.stderr)
128        sys.exit(exitcode)
129    else:
130        print(f"{msg} stdout from {rank}")
131        print(f"{msg} stderr from {rank}", file=sys.stderr)
132        return f"{msg}_{rank}"
133
134
135def echo2(msg: str, fail: bool = False) -> str:
136    """
137    returns ``msg`` or raises a RuntimeError if ``fail`` is set
138    """
139    if fail:
140        raise RuntimeError(msg)
141    return msg
142
143
144def echo_large(size: int) -> Dict[int, str]:
145    """
146    returns a large output ({0: test0", 1: "test1", ..., (size-1):f"test{size-1}"})
147    """
148    out = {}
149    for idx in range(0, size):
150        out[idx] = f"test{idx}"
151    return out
152
153
154def echo3(msg: str, fail: bool = False) -> str:
155    """
156    returns ``msg`` or induces a SIGSEGV if ``fail`` is set
157    """
158    if fail:
159        ctypes.string_at(0)
160    return msg
161
162
163def dummy_compute() -> torch.Tensor:
164    """
165    returns a predefined size random Tensor
166    """
167    return torch.rand(100, 100)
168
169
170def redirects_oss_test() -> List[Std]:
171    return [
172        Std.NONE,
173    ]
174
175
176def redirects_all() -> List[Std]:
177    return [
178        Std.NONE,
179        Std.OUT,
180        Std.ERR,
181        Std.ALL,
182    ]
183
184
185def bin(name: str):
186    dir = os.path.dirname(__file__)
187    return os.path.join(dir, "bin", name)
188
189
190def wait_fn(wait_time: int = 300) -> None:
191    time.sleep(wait_time)
192    print("Finished waiting")
193
194
195def start_processes_zombie_test(
196    idx: int,
197    entrypoint: Union[str, Callable],
198    mp_queue: mp.Queue,
199    log_dir: str,
200    nproc: int = 2,
201) -> None:
202    """
203    Starts processes
204    """
205
206    args = {}
207    envs = {}
208    for idx in range(nproc):
209        args[idx] = ()
210        envs[idx] = {}
211
212    pc = start_processes(
213        name="zombie_test",
214        entrypoint=entrypoint,
215        args=args,
216        envs=envs,
217        logs_specs=DefaultLogsSpecs(log_dir=log_dir),
218    )
219    my_pid = os.getpid()
220    mp_queue.put(my_pid)
221    for child_pid in pc.pids().values():
222        mp_queue.put(child_pid)
223
224    try:
225        pc.wait(period=1, timeout=300)
226    except SignalException as e:
227        pc.close(e.sigval)
228
229
230class _StartProcessesTest(TestCase):
231    def setUp(self):
232        super().setUp()
233        self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
234        self._start_methods = ["spawn"]
235
236    def tearDown(self):
237        super().tearDown()
238        shutil.rmtree(self.test_dir)
239
240    def log_dir(self):
241        return tempfile.mkdtemp(dir=self.test_dir)
242
243    def assert_in_file(self, expected: List[str], filename: str) -> None:
244        expected = [f"{line.rstrip()}\n" for line in expected]
245        with open(filename) as fp:
246            actual = fp.readlines()
247            for line in expected:
248                self.assertIn(line, actual)
249
250    def assert_pids_noexist(self, pids: Dict[int, int]):
251        for local_rank, pid in pids.items():
252            with self.assertRaises(
253                OSError, msg=f"local_rank: {local_rank} pid: {pid} should not exist"
254            ):
255                os.kill(pid, 0)
256
257    def _test_zombie_workflow(
258        self, entrypoint: Union[str, Callable], signal_to_send: signal.Signals
259    ) -> None:
260        mp_queue = mp.get_context("spawn").Queue()
261        child_nproc = 2
262        ctx = mp.spawn(
263            start_processes_zombie_test,
264            nprocs=1,
265            args=(entrypoint, mp_queue, self.log_dir(), child_nproc),
266            join=False,
267        )
268        total_processes = child_nproc + 1
269        pids = []
270        for _ in range(total_processes):
271            pids.append(mp_queue.get(timeout=120))
272        parent_pid = pids[0]
273        child_pids = pids[1:]
274
275        os.kill(parent_pid, signal.SIGTERM)
276        # Wait to give time for signal handlers to finish work
277        time.sleep(5)
278        for child_pid in child_pids:
279            # Killing parent should kill all children, we expect that each call to
280            # os.kill would raise OSError
281            with self.assertRaises(OSError):
282                os.kill(child_pid, 0)
283
284
285# tests incompatible with tsan or asan
286if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
287
288    class StartProcessesAsFuncTest(_StartProcessesTest):
289        def test_to_map(self):
290            local_world_size = 2
291            self.assertEqual(
292                {0: Std.OUT, 1: Std.OUT}, to_map(Std.OUT, local_world_size)
293            )
294            self.assertEqual(
295                {0: Std.NONE, 1: Std.OUT}, to_map({1: Std.OUT}, local_world_size)
296            )
297            self.assertEqual(
298                {0: Std.ERR, 1: Std.OUT},
299                to_map({0: Std.ERR, 1: Std.OUT}, local_world_size),
300            )
301
302        def test_invalid_log_dir(self):
303            with tempfile.NamedTemporaryFile(dir=self.test_dir) as not_a_dir:
304                cases = {
305                    not_a_dir.name: NotADirectoryError,
306                }
307
308                for log_dir, expected_error in cases.items():
309                    with self.subTest(log_dir=log_dir, expected_error=expected_error):
310                        with self.assertRaises(expected_error):
311                            pc = None
312                            try:
313                                pc = start_processes(
314                                    name="echo",
315                                    entrypoint=echo1,
316                                    args={0: ("hello",)},
317                                    envs={0: {"RANK": "0"}},
318                                    logs_specs=DefaultLogsSpecs(log_dir=log_dir),
319                                )
320                            finally:
321                                if pc:
322                                    pc.close()
323
324        def test_args_env_len_mismatch(self):
325            cases = [
326                # 1 x args; 2 x envs
327                {
328                    "args": {0: ("hello",)},
329                    "envs": {0: {"RANK": "0"}, 1: {"RANK": "1"}},
330                },
331                # 2 x args; 1 x envs
332                {
333                    "args": {0: ("hello",), 1: ("world",)},
334                    "envs": {0: {"RANK": "0"}},
335                },
336            ]
337
338            for kwds in cases:
339                args = kwds["args"]
340                envs = kwds["envs"]
341                with self.subTest(args=args, envs=envs):
342                    with self.assertRaises(RuntimeError):
343                        start_processes(
344                            name="echo",
345                            entrypoint=echo1,
346                            args=args,
347                            envs=envs,
348                            logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
349                        )
350
351        def test_pcontext_wait(self):
352            pc = start_processes(
353                name="sleep",
354                entrypoint=time.sleep,
355                args={0: (1,)},
356                envs={0: {}},
357                logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
358                start_method="spawn",
359            )
360
361            self.assertIsNone(pc.wait(timeout=0.1, period=0.01))
362            self.assertIsNotNone(pc.wait(period=0.1))
363            self.assertTrue(pc._stderr_tail.stopped())
364            self.assertTrue(pc._stdout_tail.stopped())
365
366        def test_pcontext_wait_on_a_child_thread(self):
367            asyncio.run(asyncio.to_thread(self.test_pcontext_wait))
368
369        def test_multiprocess_context_close(self):
370            pc = start_processes(
371                name="sleep",
372                entrypoint=time.sleep,
373                args={0: (1,)},
374                envs={0: {}},
375                logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
376                start_method="spawn",
377            )
378
379            pids = pc.pids()
380            pc.close()
381            self.assert_pids_noexist(pids)
382            self.assertTrue(pc._stderr_tail.stopped())
383            self.assertTrue(pc._stdout_tail.stopped())
384
385        def test_function_with_tensor(self):
386            for start_method in self._start_methods:
387                pc = start_processes(
388                    name="dummy_compute",
389                    entrypoint=dummy_compute,
390                    args={0: ()},
391                    envs={0: {}},
392                    logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
393                    start_method=start_method,
394                )
395
396                results = pc.wait()
397                self.assert_pids_noexist(pc.pids())
398                for return_value in results.return_values.values():
399                    self.assertIsInstance(return_value, torch.Tensor)
400                    self.assertEqual((100, 100), return_value.shape)
401
402        def test_void_function(self):
403            for start_method in self._start_methods:
404                with self.subTest(start_method=start_method):
405                    pc = start_processes(
406                        name="echo",
407                        entrypoint=echo0,
408                        args={0: ("hello",), 1: ("world",)},
409                        envs={0: {}, 1: {}},
410                        logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
411                        start_method=start_method,
412                    )
413
414                    results = pc.wait(period=0.1)
415                    self.assertEqual({0: None, 1: None}, results.return_values)
416
417        @skip_but_pass_in_sandcastle_if(
418            TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan"
419        )
420        def test_function_large_ret_val(self):
421            # python multiprocessing.queue module uses pipes and actually PipedQueues
422            # This means that if a single object is greater than a pipe size
423            # the writer process will block until reader process will start
424            # reading the pipe.
425            # This test makes a worker fn to return huge output, around ~10 MB
426
427            size = 200000
428            for start_method in self._start_methods:
429                with self.subTest(start_method=start_method):
430                    pc = start_processes(
431                        logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
432                        name="echo",
433                        entrypoint=echo_large,
434                        args={0: (size,), 1: (size,), 2: (size,), 3: (size,)},
435                        envs={0: {}, 1: {}, 2: {}, 3: {}},
436                        start_method=start_method,
437                    )
438
439                    results = pc.wait(period=0.1)
440                    for i in range(pc.nprocs):
441                        self.assertEqual(size, len(results.return_values[i]))
442
443        def test_function_raise(self):
444            """
445            run 2x copies of echo2, raise an exception on the first
446            """
447            RAISE = True
448
449            for start_method in self._start_methods:
450                with self.subTest(start_method=start_method):
451                    log_dir = self.log_dir()
452                    pc = start_processes(
453                        name="echo",
454                        entrypoint=echo2,
455                        args={0: ("hello", RAISE), 1: ("world",)},
456                        envs={
457                            0: {"TORCHELASTIC_RUN_ID": "run_id"},
458                            1: {"TORCHELASTIC_RUN_ID": "run_id"},
459                        },
460                        logs_specs=DefaultLogsSpecs(log_dir=log_dir),
461                        start_method=start_method,
462                    )
463
464                    results = pc.wait(period=0.1)
465
466                    self.assert_pids_noexist(pc.pids())
467                    self.assertEqual(1, len(results.failures))
468                    self.assertFalse(results.return_values)
469
470                    failure = results.failures[0]
471                    error_file = failure.error_file
472                    error_file_data = failure.error_file_data
473
474                    self.assertEqual(1, failure.exitcode)
475                    self.assertEqual("<N/A>", failure.signal_name())
476                    self.assertEqual(pc.pids()[0], failure.pid)
477                    self.assertTrue(
478                        error_file.startswith(os.path.join(log_dir, "run_id_"))
479                    )
480                    self.assertTrue(error_file.endswith("attempt_0/0/error.json"))
481                    self.assertEqual(
482                        int(error_file_data["message"]["extraInfo"]["timestamp"]),
483                        int(failure.timestamp),
484                    )
485                    self.assertTrue(pc._stderr_tail.stopped())
486                    self.assertTrue(pc._stdout_tail.stopped())
487
488        def test_wait_for_all_child_procs_to_exit(self):
489            """
490            Tests that MultiprocessingContext actually waits for
491            the child process to exit (not just that the entrypoint fn has
492            finished running).
493            """
494
495            mpc = MultiprocessContext(
496                name="echo",
497                entrypoint=echo0,
498                args={},
499                envs={},
500                start_method="spawn",
501                logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
502            )
503
504            with mock.patch.object(
505                mpc, "_is_done", return_value=True
506            ), mock.patch.object(mpc, "_pc"), mock.patch.object(
507                mpc._pc, "join", side_effect=[True, False, False, True]
508            ) as mock_join:
509                mpc._poll()
510                self.assertEqual(4, mock_join.call_count)
511
512        @skip_but_pass_in_sandcastle_if(
513            NO_MULTIPROCESSING_SPAWN,
514            "Disabled for environments that \
515                        don't support multiprocessing with spawn start method",
516        )
517        def test_multiprocessing_context_poll_raises_exception(self):
518            mp_context = MultiprocessContext(
519                name="test_mp",
520                entrypoint=echo0,
521                args={0: (0, 1)},
522                envs={0: {}},
523                logs_specs=DefaultLogsSpecs(
524                    log_dir=self.log_dir(), redirects=Std.ALL, tee=Std.ALL
525                ),
526                start_method="spawn",
527            )
528            mp_context._pc = mock.Mock()
529            # Using mock since we cannot just set exitcode on process
530            mock_process = mock.Mock()
531            mock_process.exitcode = -1
532            mp_context._pc.processes = [mock_process]
533            e = mp.ProcessRaisedException(msg="test msg", error_index=0, error_pid=123)
534            mp_context._pc.join.side_effect = e
535            with mock.patch.object(mp_context, "close"):
536                run_result = mp_context._poll()
537                self.assertEqual(1, len(run_result.failures))
538                failure = run_result.failures[0]
539                self.assertEqual(
540                    "Signal 1 (SIGHUP) received by PID 123", failure.message
541                )
542
543    class StartProcessesAsBinaryTest(_StartProcessesTest):
544        ########################################
545        # start_processes as binary tests
546        ########################################
547
548        def test_subprocess_context_close(self):
549            pc = start_processes(
550                name="sleep",
551                entrypoint=bin("zombie_test.py"),
552                args={0: (1,)},
553                envs={0: {}},
554                logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
555            )
556
557            pids = pc.pids()
558            pc.close()
559            self.assert_pids_noexist(pids)
560
561        def test_binary_exit(self):
562            FAIL = 138
563            pc = start_processes(
564                name="echo",
565                entrypoint=bin("echo1.py"),
566                args={0: ("--exitcode", FAIL, "foo"), 1: ("--exitcode", 0, "bar")},
567                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
568                logs_specs=DefaultLogsSpecs(
569                    log_dir=self.log_dir(),
570                    redirects={0: Std.ALL},
571                ),
572            )
573
574            results = pc.wait(period=0.1)
575
576            self.assertTrue(results.is_failed())
577            self.assertEqual(1, len(results.failures))
578
579            failure = results.failures[0]
580            self.assertEqual(138, failure.exitcode)
581            self.assertEqual("<N/A>", failure.signal_name())
582            self.assertEqual("<NONE>", failure.error_file_data["message"])
583            self.assert_in_file([f"exit {FAIL} from 0"], results.stderrs[0])
584            self.assert_in_file([], results.stdouts[0])
585            self.assertFalse(results.stderrs[1])
586            self.assertFalse(results.stdouts[1])
587            self.assertTrue(pc._stderr_tail.stopped())
588            self.assertTrue(pc._stdout_tail.stopped())
589
590        def test_binary_raises(self):
591            pc = start_processes(
592                name="echo",
593                entrypoint=bin("echo2.py"),
594                args={0: ("--raises", "true", "foo"), 1: ("bar",)},
595                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
596                logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
597            )
598
599            results = pc.wait(period=0.1)
600
601            self.assert_pids_noexist(pc.pids())
602            self.assertTrue(results.is_failed())
603            self.assertEqual(1, len(results.failures))
604
605            failure = results.failures[0]
606            self.assertEqual(1, failure.exitcode)
607            self.assertEqual("<NONE>", failure.error_file_data["message"])
608            self.assertEqual("<N/A>", failure.signal_name())
609
610        def test_binary_incorrect_entrypoint(self):
611            with self.assertRaises(FileNotFoundError):
612                start_processes(
613                    name="echo",
614                    entrypoint="does_not_exist.py",
615                    args={0: ("foo"), 1: ("bar",)},
616                    envs={0: {}, 1: {}},
617                    logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()),
618                )
619
620        def test_validate_full_rank(self):
621            with self.assertRaises(RuntimeError):
622                _validate_full_rank({}, 10, "")
623
624
625# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
626if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
627
628    class StartProcessesListAsFuncTest(_StartProcessesTest):
629        def test_function(self):
630            for start_method, redirs in product(
631                self._start_methods, redirects_oss_test()
632            ):
633                with self.subTest(start_method=start_method, redirs=redirs):
634                    pc = start_processes(
635                        name="echo",
636                        entrypoint=echo1,
637                        args={0: ("hello",), 1: ("hello",)},
638                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
639                        logs_specs=DefaultLogsSpecs(
640                            log_dir=self.log_dir(),
641                            redirects=redirs,
642                        ),
643                        start_method=start_method,
644                    )
645
646                    results = pc.wait(period=0.1)
647                    nprocs = pc.nprocs
648
649                    self.assert_pids_noexist(pc.pids())
650                    self.assertEqual(
651                        {i: f"hello_{i}" for i in range(nprocs)}, results.return_values
652                    )
653
654                    for i in range(nprocs):
655                        if redirs & Std.OUT != Std.OUT:
656                            self.assertFalse(results.stdouts[i])
657                        if redirs & Std.ERR != Std.ERR:
658                            self.assertFalse(results.stderrs[i])
659                        if redirs & Std.OUT == Std.OUT:
660                            self.assert_in_file(
661                                [f"hello stdout from {i}"], results.stdouts[i]
662                            )
663                        if redirs & Std.ERR == Std.ERR:
664                            self.assert_in_file(
665                                [f"hello stderr from {i}"], results.stderrs[i]
666                            )
667
668    class StartProcessesListAsBinaryTest(_StartProcessesTest):
669        ########################################
670        # start_processes as binary tests
671        ########################################
672        def test_binary(self):
673            for redirs in redirects_oss_test():
674                with self.subTest(redirs=redirs):
675                    pc = start_processes(
676                        name="echo",
677                        entrypoint=bin("echo1.py"),
678                        args={0: ("hello",), 1: ("hello",)},
679                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
680                        logs_specs=DefaultLogsSpecs(
681                            log_dir=self.log_dir(),
682                            redirects=redirs,
683                        ),
684                        log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"},
685                    )
686
687                    results = pc.wait(period=0.1)
688
689                    self.assert_pids_noexist(pc.pids())
690                    # currently binaries return {rank: None}
691                    self.assertEqual(2, len(results.return_values))
692                    self.assertFalse(results.is_failed())
693
694                    nprocs = pc.nprocs
695                    for i in range(nprocs):
696                        if redirs & Std.OUT != Std.OUT:
697                            self.assertFalse(results.stdouts[i])
698                        if redirs & Std.ERR != Std.ERR:
699                            self.assertFalse(results.stderrs[i])
700                        if redirs & Std.OUT == Std.OUT:
701                            self.assert_in_file(
702                                [f"hello stdout from {i}"], results.stdouts[i]
703                            )
704                        if redirs & Std.ERR == Std.ERR:
705                            self.assert_in_file(
706                                [f"hello stderr from {i}"], results.stderrs[i]
707                            )
708
709        def test_binary_redirect_and_tee(self):
710            pc = start_processes(
711                name="trainer",
712                entrypoint=bin("echo1.py"),
713                args={0: ("hello",), 1: ("world",)},
714                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
715                logs_specs=DefaultLogsSpecs(
716                    log_dir=self.log_dir(),
717                    redirects={0: Std.ERR, 1: Std.NONE},
718                    tee={0: Std.OUT, 1: Std.ERR},
719                ),
720                log_line_prefixes={0: "[rank0]:", 1: "[rank1]:"},
721                start_method="spawn",
722            )
723
724            result = pc.wait()
725
726            self.assertFalse(result.is_failed())
727            self.assert_in_file(["hello stdout from 0"], pc.stdouts[0])
728            self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
729            self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
730            self.assertFalse(pc.stdouts[1])
731            self.assertTrue(pc._stderr_tail.stopped())
732            self.assertTrue(pc._stdout_tail.stopped())
733
734
735# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
736if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_CI):
737
738    class StartProcessesNotCIAsFuncTest(_StartProcessesTest):
739        @skip_if_pytest
740        def test_wrap_bad(self):
741            none = ""
742            stdout_log = os.path.join(self.test_dir, "stdout.log")
743            stderr_log = os.path.join(self.test_dir, "stderr.log")
744            redirs = [
745                (none, none),
746                (none, stderr_log),
747                (stdout_log, none),
748                (stdout_log, stderr_log),
749            ]
750
751            for stdout_redir, stderr_redir in redirs:
752                queue = multiprocessing.SimpleQueue()
753                worker_finished_event_mock = mock.Mock()
754                _wrap(
755                    local_rank=0,
756                    fn=echo1,
757                    args={0: ("hello",)},
758                    envs={0: {"RANK": "0"}},
759                    stdout_redirects={0: stdout_redir},
760                    stderr_redirects={0: stderr_redir},
761                    ret_vals={0: queue},
762                    queue_finished_reading_event=worker_finished_event_mock,
763                )
764                self.assertEqual("hello_0", queue.get())
765                if stdout_redir:
766                    self.assert_in_file(["hello stdout from 0"], stdout_log)
767                if stderr_redir:
768                    self.assert_in_file(["hello stderr from 0"], stderr_log)
769                worker_finished_event_mock.wait.assert_called_once()
770
771        def test_function_redirect_and_tee(self):
772            for start_method in self._start_methods:
773                with self.subTest(start_method=start_method):
774                    pc = start_processes(
775                        name="trainer",
776                        entrypoint=echo1,
777                        args={0: ("hello",), 1: ("world",)},
778                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
779                        logs_specs=DefaultLogsSpecs(
780                            log_dir=self.log_dir(),
781                            redirects={0: Std.ERR, 1: Std.NONE},
782                            tee={0: Std.OUT, 1: Std.ERR},
783                        ),
784                        start_method="spawn",
785                    )
786
787                    result = pc.wait()
788
789                    self.assertFalse(result.is_failed())
790                    self.assert_in_file(["hello stdout from 0"], pc.stdouts[0])
791                    self.assert_in_file(["hello stderr from 0"], pc.stderrs[0])
792                    self.assert_in_file(["world stderr from 1"], pc.stderrs[1])
793                    self.assertFalse(pc.stdouts[1])
794                    self.assertTrue(pc._stderr_tail.stopped())
795                    self.assertTrue(pc._stdout_tail.stopped())
796
797        def test_function(self):
798            for start_method, redirs in product(self._start_methods, redirects_all()):
799                with self.subTest(start_method=start_method, redirs=redirs):
800                    pc = start_processes(
801                        name="echo",
802                        entrypoint=echo1,
803                        args={0: ("hello",), 1: ("hello",)},
804                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
805                        start_method=start_method,
806                        logs_specs=DefaultLogsSpecs(
807                            log_dir=self.log_dir(),
808                            redirects=redirs,
809                        ),
810                    )
811
812                    results = pc.wait(period=0.1)
813                    nprocs = pc.nprocs
814
815                    self.assert_pids_noexist(pc.pids())
816                    self.assertEqual(
817                        {i: f"hello_{i}" for i in range(nprocs)}, results.return_values
818                    )
819
820                    for i in range(nprocs):
821                        if redirs & Std.OUT != Std.OUT:
822                            self.assertFalse(results.stdouts[i])
823                        if redirs & Std.ERR != Std.ERR:
824                            self.assertFalse(results.stderrs[i])
825                        if redirs & Std.OUT == Std.OUT:
826                            self.assert_in_file(
827                                [f"hello stdout from {i}"], results.stdouts[i]
828                            )
829                        if redirs & Std.ERR == Std.ERR:
830                            self.assert_in_file(
831                                [f"hello stderr from {i}"], results.stderrs[i]
832                            )
833
834        def test_function_exit(self):
835            """
836            run 2x copies of echo1 fail (exit) the first
837            functions that exit from python do not generate an error file
838            (even if they are decorated with @record)
839            """
840
841            FAIL = 138
842            for start_method in self._start_methods:
843                with self.subTest(start_method=start_method):
844                    pc = start_processes(
845                        name="echo",
846                        entrypoint=echo1,
847                        args={0: ("hello", FAIL), 1: ("hello",)},
848                        envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
849                        logs_specs=DefaultLogsSpecs(
850                            log_dir=self.log_dir(),
851                            redirects={0: Std.ERR},
852                        ),
853                        start_method=start_method,
854                    )
855
856                    results = pc.wait(period=0.1)
857
858                    self.assert_pids_noexist(pc.pids())
859                    self.assertTrue(results.is_failed())
860                    self.assertEqual(1, len(results.failures))
861                    self.assertFalse(results.return_values)
862
863                    failure = results.failures[0]
864                    error_file = failure.error_file
865
866                    self.assertEqual(FAIL, failure.exitcode)
867                    self.assertEqual("<N/A>", failure.signal_name())
868                    self.assertEqual(pc.pids()[0], failure.pid)
869                    self.assertEqual("<N/A>", error_file)
870                    self.assertEqual(
871                        "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html",
872                        failure.message,
873                    )
874                    self.assertLessEqual(failure.timestamp, int(time.time()))
875
876                    self.assert_in_file([f"exit {FAIL} from 0"], results.stderrs[0])
877                    self.assertFalse(results.stdouts[0])
878                    self.assertFalse(results.stderrs[1])
879                    self.assertFalse(results.stdouts[1])
880                    self.assertTrue(pc._stderr_tail.stopped())
881                    self.assertTrue(pc._stdout_tail.stopped())
882
883        def test_no_zombie_process_function(self):
884            signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
885            for s in signals:
886                self._test_zombie_workflow(wait_fn, s)
887
888    class StartProcessesNotCIAsBinaryTest(_StartProcessesTest):
889        def test_binary_signal(self):
890            pc = start_processes(
891                name="echo",
892                entrypoint=bin("echo3.py"),
893                args={0: ("--segfault", "true", "foo"), 1: ("bar",)},
894                envs={0: {"RANK": "0"}, 1: {"RANK": "1"}},
895                logs_specs=DefaultLogsSpecs(
896                    log_dir=self.log_dir(),
897                ),
898            )
899
900            results = pc.wait(period=0.1)
901
902            self.assert_pids_noexist(pc.pids())
903            self.assertTrue(results.is_failed())
904            self.assertEqual(1, len(results.failures))
905
906            failure = results.failures[0]
907            self.assertNotEqual(signal.SIGSEGV, failure.exitcode)
908            if TEST_WITH_ASAN or TEST_WITH_TSAN:
909                # ASAN/TSAN exit code is 1.
910                self.assertEqual("<N/A>", failure.signal_name())
911            else:
912                self.assertEqual("SIGSEGV", failure.signal_name())
913            self.assertEqual("<NONE>", failure.error_file_data["message"])
914
915        def test_no_zombie_process_binary(self):
916            signals = [signal.SIGTERM, signal.SIGINT, signal.SIGHUP, signal.SIGQUIT]
917            for s in signals:
918                self._test_zombie_workflow(bin("zombie_test.py"), s)
919
920    class ForkServerTest(
921        StartProcessesAsFuncTest,
922        StartProcessesListAsFuncTest,
923        StartProcessesNotCIAsFuncTest,
924    ):
925        def setUp(self):
926            super().setUp()
927            self._start_methods = ["forkserver"]
928            self.orig_paralell_env_val = os.environ.get(mp.ENV_VAR_PARALLEL_START)
929            os.environ[mp.ENV_VAR_PARALLEL_START] = "1"
930
931        def tearDown(self):
932            super().tearDown()
933            if self.orig_paralell_env_val is None:
934                del os.environ[mp.ENV_VAR_PARALLEL_START]
935            else:
936                os.environ[mp.ENV_VAR_PARALLEL_START] = self.orig_paralell_env_val
937
938
939if __name__ == "__main__":
940    run_tests()
941