#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import multiprocessing as mp
import signal
import time

import torch.distributed.elastic.timer as timer
import torch.multiprocessing as torch_mp
from torch.testing._internal.common_utils import (
    IS_MACOS,
    IS_WINDOWS,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)


logging.basicConfig(
    level=logging.INFO, format="[%(levelname)s] %(asctime)s %(module)s: %(message)s"
)


def _happy_function(rank, mp_queue):
    """Worker that finishes well within its timer.

    Installs a ``LocalTimerClient`` backed by ``mp_queue`` and sleeps for
    0.5s inside a 1s ``timer.expires`` block, so it is expected to exit
    normally (exit code 0).

    Args:
        rank: process index supplied first by the spawner; unused here.
        mp_queue: queue shared with the ``LocalTimerServer``.
    """
    timer.configure(timer.LocalTimerClient(mp_queue))
    with timer.expires(after=1):
        time.sleep(0.5)


def _stuck_function(rank, mp_queue):
    """Worker that overruns its timer.

    Installs a ``LocalTimerClient`` backed by ``mp_queue`` and sleeps for 5s
    inside a 1s ``timer.expires`` block. The tests in this file expect the
    ``LocalTimerServer`` to kill it (exit code ``-SIGKILL``) once the timer
    expires.

    Args:
        rank: process index supplied first by the spawner; unused here.
        mp_queue: queue shared with the ``LocalTimerServer``.
    """
    timer.configure(timer.LocalTimerClient(mp_queue))
    with timer.expires(after=1):
        time.sleep(5)


# timer is not supported on macos or windows
if not (IS_WINDOWS or IS_MACOS):

    class LocalTimerExample(TestCase):
        """
        Demonstrates how to use LocalTimerServer and LocalTimerClient
        to enforce expiration of code-blocks.

        Since torch multiprocessing's ``start_process`` method currently
        does not take the multiprocessing context as parameter argument
        there is no way to create the mp.Queue in the correct
        context BEFORE spawning child processes. Once the ``start_process``
        API is changed in torch, then re-enable ``test_torch_mp_example``
        unittest. As of now this will SIGSEGV.
        """

        @skip_but_pass_in_sandcastle_if(
            TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible"
        )
        def test_torch_mp_example(self):
            # NOTE(review): the class docstring says this test should stay
            # disabled, but the only guard visible in this file is the ASAN
            # skip above -- confirm how it is actually disabled in CI.

            # in practice set the max_interval to a larger value (e.g. 60 seconds)
            mp_queue = mp.get_context("spawn").Queue()
            server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
            server.start()

            # Stop the server even when a spawn or assertion fails, so it
            # does not keep running after this test finishes.
            try:
                world_size = 8

                # all processes should complete successfully
                # since start_process does NOT take context as parameter argument
                # yet, this call is expected to fail (see the class docstring)
                torch_mp.spawn(
                    fn=_happy_function, args=(mp_queue,), nprocs=world_size, join=True
                )

                with self.assertRaises(Exception):
                    # torch.multiprocessing.spawn kills all sub-procs
                    # if one of them gets killed
                    torch_mp.spawn(
                        fn=_stuck_function,
                        args=(mp_queue,),
                        nprocs=world_size,
                        join=True,
                    )
            finally:
                server.stop()

        @skip_but_pass_in_sandcastle_if(
            TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible"
        )
        def test_example_start_method_spawn(self):
            self._run_example_with(start_method="spawn")

        # @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test is asan incompatible")
        # def test_example_start_method_forkserver(self):
        #     self._run_example_with(start_method="forkserver")

        def _run_example_with(self, start_method):
            """Run stuck and happy workers under a ``LocalTimerServer``.

            Even ranks run ``_stuck_function`` and must be killed by the
            server (exit code ``-SIGKILL``). Odd ranks run ``_happy_function``
            and must exit with 0.

            Args:
                start_method: multiprocessing start method used both for the
                    shared queue and for the worker processes.
            """
            spawn_ctx = mp.get_context(start_method)
            mp_queue = spawn_ctx.Queue()
            server = timer.LocalTimerServer(mp_queue, max_interval=0.01)
            server.start()

            world_size = 8
            processes = []
            try:
                for rank in range(world_size):
                    target = _stuck_function if rank % 2 == 0 else _happy_function
                    p = spawn_ctx.Process(target=target, args=(rank, mp_queue))
                    p.start()
                    processes.append(p)

                for rank, p in enumerate(processes):
                    p.join()
                    if rank % 2 == 0:
                        self.assertEqual(-signal.SIGKILL, p.exitcode)
                    else:
                        self.assertEqual(0, p.exitcode)
            finally:
                # If an assertion failed part-way, some workers may not have
                # been joined yet; reap them so they are not leaked, then
                # stop the server.
                for p in processes:
                    if p.is_alive():
                        p.kill()
                        p.join()
                server.stop()


if __name__ == "__main__":
    run_tests()