xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/timer/local_timer_example.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import logging
10import multiprocessing as mp
11import signal
12import time
13
14import torch.distributed.elastic.timer as timer
15import torch.multiprocessing as torch_mp
16from torch.testing._internal.common_utils import (
17    IS_MACOS,
18    IS_WINDOWS,
19    run_tests,
20    skip_but_pass_in_sandcastle_if,
21    TEST_WITH_DEV_DBG_ASAN,
22    TestCase,
23)
24
25
# Configure root logging for this example: INFO and above, each record
# prefixed with its level, timestamp and originating module.
logging.basicConfig(
    format="[%(levelname)s] %(asctime)s %(module)s: %(message)s",
    level=logging.INFO,
)
29
30
def _happy_function(rank, mp_queue):
    """Worker that finishes inside its timer budget.

    Configures ``timer`` to use a ``LocalTimerClient`` that talks over
    ``mp_queue``, then sleeps for 0.5s inside a 1s ``timer.expires`` window,
    so the timer never expires.

    ``rank`` is not used. It is the leading positional argument supplied both
    by ``torch_mp.spawn`` and by the explicit ``Process`` launches in this file.
    """
    client = timer.LocalTimerClient(mp_queue)
    timer.configure(client)
    budget_s, work_s = 1, 0.5
    with timer.expires(after=budget_s):
        time.sleep(work_s)
35
36
def _stuck_function(rank, mp_queue):
    """Worker that deliberately overruns its timer budget.

    Configures ``timer`` to use a ``LocalTimerClient`` that talks over
    ``mp_queue``, then sleeps for 5s inside a 1s ``timer.expires`` window.
    The tests in this file expect the ``LocalTimerServer`` to SIGKILL the
    process before the sleep completes.

    ``rank`` is not used; it only satisfies the launchers' calling convention.
    """
    timer_client = timer.LocalTimerClient(mp_queue)
    timer.configure(timer_client)
    with timer.expires(after=1):
        # Five times the budget, so the timer is certain to expire first.
        time.sleep(5)
41
42
# timer is not supported on macos or windows, so the example tests are only
# defined on other platforms.
if not (IS_WINDOWS or IS_MACOS):

    class LocalTimerExample(TestCase):
        """
        Demonstrates how to use LocalTimerServer and LocalTimerClient
        to enforce expiration of code-blocks.

        Since torch multiprocessing's ``start_process`` method currently
        does not take the multiprocessing context as an argument,
        there is no way to create the mp.Queue in the correct
        context BEFORE spawning child processes. Once the ``start_process``
        API is changed in torch, re-enable the ``test_torch_mp_example``
        unittest. As of now this will SIGSEGV.

        NOTE(review): ``test_torch_mp_example`` is only skipped under ASAN
        (``TEST_WITH_DEV_DBG_ASAN``); it is not disabled outright. Confirm
        whether the note above is stale.
        """

        @skip_but_pass_in_sandcastle_if(
            TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible"
        )
        def test_torch_mp_example(self):
            """Drive happy and stuck workers through ``torch_mp.spawn``.

            The first ``spawn`` runs only ``_happy_function`` ranks and must
            succeed. The second runs only ``_stuck_function`` ranks and is
            expected to raise once a worker is killed.
            """
            # in practice set the max_interval to a larger value (e.g. 60 seconds)
            mp_queue = mp.get_context("spawn").Queue()
            server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
            server.start()
            # Stop the server even when spawn or an assertion fails, so it
            # is not left running after the test.
            try:
                world_size = 8

                # all processes should complete successfully
                # since start_process does NOT take context as parameter argument yet
                # this method WILL FAIL (hence the test is disabled)
                torch_mp.spawn(
                    fn=_happy_function,
                    args=(mp_queue,),
                    nprocs=world_size,
                    join=True,
                )

                with self.assertRaises(Exception):
                    # torch.multiprocessing.spawn kills all sub-procs
                    # if one of them gets killed
                    torch_mp.spawn(
                        fn=_stuck_function,
                        args=(mp_queue,),
                        nprocs=world_size,
                        join=True,
                    )
            finally:
                server.stop()

        @skip_but_pass_in_sandcastle_if(
            TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible"
        )
        def test_example_start_method_spawn(self):
            self._run_example_with(start_method="spawn")

        # @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
        # def test_example_start_method_forkserver(self):
        #     self._run_example_with(start_method="forkserver")

        def _run_example_with(self, start_method):
            """Launch a mix of happy and stuck workers and check their exit codes.

            Args:
                start_method: multiprocessing start method name passed to
                    ``mp.get_context`` (e.g. ``"spawn"``).

            Even ranks run ``_stuck_function`` and must end with exit code
            ``-SIGKILL``. Odd ranks run ``_happy_function`` and must exit with 0.
            """
            spawn_ctx = mp.get_context(start_method)
            mp_queue = spawn_ctx.Queue()
            # in practice set the max_interval to a larger value (e.g. 60 seconds)
            server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
            server.start()

            world_size = 8
            processes = []
            try:
                for rank in range(world_size):
                    target = _stuck_function if rank % 2 == 0 else _happy_function
                    proc = spawn_ctx.Process(target=target, args=(rank, mp_queue))
                    proc.start()
                    processes.append(proc)

                for rank, proc in enumerate(processes):
                    proc.join()
                    if rank % 2 == 0:
                        self.assertEqual(-signal.SIGKILL, proc.exitcode)
                    else:
                        self.assertEqual(0, proc.exitcode)
            finally:
                # On the success path every worker has already been joined,
                # so this loop does nothing. After an early failure it reaps
                # any worker still running so none outlive the test.
                for proc in processes:
                    if proc.is_alive():
                        proc.kill()
                        proc.join()
                server.stop()
121
122
if __name__ == "__main__":
    # Entry point when executed directly: delegate to the shared ``run_tests``
    # helper imported from torch.testing._internal.common_utils.
    run_tests()
125