xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/agent/server/test/api_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10
11import signal
12import unittest
13import uuid
14from multiprocessing.pool import ThreadPool
15from typing import Any, Dict, List
16from unittest.mock import call, patch
17
18import torch.distributed as dist
19import torch.distributed.elastic.rendezvous.registry as rdzv_registry
20from torch.distributed.elastic.agent.server.api import (
21    _get_fq_hostname,
22    _RoleInstanceInfo,
23    RunResult,
24    SimpleElasticAgent,
25    Worker,
26    WorkerGroup,
27    WorkerSpec,
28    WorkerState,
29)
30from torch.distributed.elastic.multiprocessing import SignalException
31from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
32from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
33from torch.distributed.elastic.rendezvous.api import RendezvousGracefulExitError
34from torch.distributed.elastic.utils.distributed import get_free_port
35from torch.testing._internal.common_utils import run_tests
36
37
def do_nothing():
    """No-op placeholder used as the worker entrypoint ``fn`` in WorkerSpec."""
    return None
40
41
class WorkerStateTest(unittest.TestCase):
    def test_is_running(self):
        """Exactly HEALTHY and UNHEALTHY count as "running" states."""
        running_states = {WorkerState.HEALTHY, WorkerState.UNHEALTHY}
        for state in WorkerState:
            if state in running_states:
                self.assertTrue(WorkerState.is_running(state))
            else:
                self.assertFalse(WorkerState.is_running(state))
49
50
class WorkerGroupTest(unittest.TestCase):
    def test_worker_group_constructor(self):
        """A freshly constructed WorkerGroup is INIT with rendezvous-assigned
        fields (ranks, store, worker ids) still unset."""
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=4,
            fn=do_nothing,
            args=(),
            rdzv_handler=None,
            max_restarts=50,
            monitor_interval=0.1,
        )
        group = WorkerGroup(spec)

        self.assertEqual(WorkerState.INIT, group.state)

        members = group.workers
        self.assertEqual(4, len(members))

        # local ranks must cover the full consecutive range [0, 4)
        self.assertSetEqual({w.local_rank for w in members}, set(range(4)))

        # global_rank / world_size are assigned only after rendezvous, and a
        # worker id only after the agent starts it, so all are still unset
        for worker in members:
            self.assertEqual(-1, worker.global_rank)
            self.assertEqual(-1, worker.world_size)
            self.assertEqual(None, worker.id)

        # group rank and store are likewise populated by rendezvous
        self.assertIsNone(group.group_rank)
        self.assertIsNone(group.store)
83
84
85class RoleInstanceInfoTest(unittest.TestCase):
86    def test_compare(self):
87        agent_role1 = _RoleInstanceInfo("role", 1, 10)
88        agent_role2 = _RoleInstanceInfo("role", 2, 10)
89        self.assertEqual(1, _RoleInstanceInfo.compare(agent_role2, agent_role1))
90        agent_role1 = _RoleInstanceInfo("role1", 1, 10)
91        agent_role2 = _RoleInstanceInfo("role2", 2, 10)
92        self.assertEqual(-1, _RoleInstanceInfo.compare(agent_role1, agent_role2))
93        agent_role1 = _RoleInstanceInfo("role1", 1, 10)
94        agent_role2 = _RoleInstanceInfo("role2", 1, 10)
95        self.assertEqual(-1, _RoleInstanceInfo.compare(agent_role1, agent_role2))
96
97    def test_serde(self):
98        agent_role = _RoleInstanceInfo("role", 1, 10)
99        str_data = agent_role.serialize()
100        actual_agent_role = _RoleInstanceInfo.deserialize(str_data)
101        self.assertEqual(agent_role.role, actual_agent_role.role)
102        self.assertEqual(agent_role.rank, actual_agent_role.rank)
103        self.assertEqual(
104            agent_role.local_world_size, actual_agent_role.local_world_size
105        )
106
107    def test_find_boundaries(self):
108        role_infos = [
109            _RoleInstanceInfo("trainer", 1, 1),
110            _RoleInstanceInfo("trainer", 2, 2),
111            _RoleInstanceInfo("trainer", 3, 3),
112            _RoleInstanceInfo("parameter_server", 4, 5),
113            _RoleInstanceInfo("parameter_server", 0, 4),
114        ]
115        start_idx, end_idx = _RoleInstanceInfo.find_role_boundaries(
116            role_infos, "trainer"
117        )
118        self.assertEqual(start_idx, 0)
119        self.assertEqual(end_idx, 2)
120
121
class TestAgent(SimpleElasticAgent):
    """Minimal concrete agent for testing: workers are fake (never spawned)
    and start/stop invocations are counted."""

    def __init__(self, spec):
        super().__init__(spec)
        self.stop_workers_call_count = 0
        self.start_workers_call_count = 0

    def _stop_workers(
        self, worker_group: WorkerGroup, is_restart: bool = False
    ) -> None:
        # nothing real to stop; just clear the rendezvous-assigned group info
        worker_group.group_rank = None
        worker_group.group_world_size = None
        self.stop_workers_call_count += 1

    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        # create fake workers: a worker's id is simply its global rank
        self.start_workers_call_count += 1
        return {w.local_rank: w.global_rank for w in worker_group.workers}

    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        raise NotImplementedError("mock this method")

    def _shutdown(self):
        pass
149
150
def monres(state: WorkerState):
    """Build a canned RunResult for the given worker state: SUCCEEDED carries
    a return value for rank 0, UNHEALTHY/FAILED carry a fake ProcessFailure,
    anything else is a bare state-only result."""
    if state == WorkerState.SUCCEEDED:
        return RunResult(state=state, return_values={0: 0}, failures={})
    if state in (WorkerState.UNHEALTHY, WorkerState.FAILED):
        failure = ProcessFailure(
            local_rank=0, pid=999, exitcode=1, error_file="<none>"
        )
        return RunResult(state=state, return_values={}, failures={0: failure})
    return RunResult(state=state)
159
160
class SimpleElasticAgentTest(unittest.TestCase):
    """Tests SimpleElasticAgent's run loop, rendezvous, rank assignment,
    metrics and event reporting through the fake ``TestAgent`` subclass
    (workers are never actually spawned; monitoring is mocked per test)."""

    def _get_worker_spec(
        self,
        max_restarts=1,
        monitor_interval=0.1,
        role="test_trainer",
        local_world_size=8,
        local_addr=None,
    ):
        """Build a single-node WorkerSpec backed by a "static" rendezvous
        bound to a locally obtained free port."""
        run_id = str(uuid.uuid4().int)
        port = get_free_port()
        if local_addr is None:
            endpoint = f"127.0.0.1:{port}"
        else:
            endpoint = f"{local_addr}:{port}"

        rdzv_params = RendezvousParameters(
            backend="static",
            endpoint=endpoint,
            run_id=run_id,
            min_nodes=1,
            max_nodes=1,
            rank=0,
        )
        rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
        spec = WorkerSpec(
            role=role,
            local_world_size=local_world_size,
            fn=do_nothing,
            args=(),
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
            local_addr=local_addr,
        )
        return spec

    def test_agent_constructor(self):
        """A new agent starts in INIT with its full restart budget."""
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()
        self.assertEqual(WorkerState.INIT, worker_group.state)
        self.assertEqual(spec.max_restarts, agent._remaining_restarts)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_flakiness_metric(self, put_metric_mock):
        """Flakiness metric is 0 with no restarts used and grows with the
        fraction of the restart budget consumed."""
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        agent._record_flakiness_metric()
        put_metric_mock.assert_called_with("workers.test_trainer.flakiness", 0)
        agent._worker_group.spec.max_restarts = 10
        agent._remaining_restarts = 3
        agent._record_flakiness_metric()
        # NOTE(review): 63 presumably derives from 7 of 10 restarts consumed
        # (100 * 7 / 11 == 63) — confirm against _record_flakiness_metric impl
        put_metric_mock.assert_called_with("workers.test_trainer.flakiness", 63)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_flakiness_metric_zero_restarts(self, put_metric_mock):
        """max_restarts == 0 must not divide by zero; flakiness reads 0."""
        spec = self._get_worker_spec(max_restarts=1)
        spec.max_restarts = 0
        agent = TestAgent(spec)
        agent._record_flakiness_metric()
        put_metric_mock.assert_called_with("workers.test_trainer.flakiness", 0)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_flakiness_metric_user_exception(self, put_metric_mock):
        """A user exception records maximum (100%) flakiness."""
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        agent._record_flakiness_metric(True)
        put_metric_mock.assert_called_with("workers.test_trainer.flakiness", 100)

    @patch.object(TestAgent, "_invoke_run")
    @patch.object(TestAgent, "_record_metrics")
    @patch.object(TestAgent, "_record_worker_events")
    @patch.object(TestAgent, "_shutdown")
    def test_invoke_run(
        self, shutdown_mock, record_events_mock, record_metrics_mock, invoke_run_mock
    ):
        """run() invokes the run loop, records metrics/events, and shuts down
        exactly once each."""
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        agent.run()
        invoke_run_mock.assert_called_once()
        record_metrics_mock.assert_called_once()
        record_events_mock.assert_called_once()
        shutdown_mock.assert_called_once()

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_metrics_success_no_retries(self, put_metric_mock):
        """Successful run with full restart budget -> success_no_retries."""
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        # NOTE(review): positional args bind as RunResult(state={}, return_values={});
        # the falsy, non-FAILED state reads as success — confirm this is intended
        group_result = RunResult({}, {})
        agent._record_metrics(group_result)
        calls = self._get_record_metrics_test_calls(success_no_retries=1)
        put_metric_mock.assert_has_calls(calls, any_order=True)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_metrics_success_with_retries(self, put_metric_mock):
        """Successful run with part of the budget consumed -> success_with_retries."""
        spec = self._get_worker_spec(max_restarts=10)
        agent = TestAgent(spec)
        agent._remaining_restarts = 2
        group_result = RunResult({}, {})
        agent._record_metrics(group_result)
        calls = self._get_record_metrics_test_calls(success_with_retries=1)
        put_metric_mock.assert_has_calls(calls, any_order=True)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_metrics_failed_with_retries(self, put_metric_mock):
        """FAILED result after consuming restarts -> failed_with_retries."""
        spec = self._get_worker_spec(max_restarts=10)
        agent = TestAgent(spec)
        agent._remaining_restarts = 2
        group_result = RunResult(
            state=WorkerState.FAILED, return_values={}, failures={0: 0}
        )
        agent._record_metrics(group_result)
        calls = self._get_record_metrics_test_calls(failed_with_retries=1)
        put_metric_mock.assert_has_calls(calls, any_order=True)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_metrics_failed_no_retries(self, put_metric_mock):
        """FAILED result with untouched restart budget -> failed_no_retries."""
        spec = self._get_worker_spec(max_restarts=10)
        agent = TestAgent(spec)
        group_result = RunResult(
            state=WorkerState.FAILED, return_values={}, failures={0: 0}
        )
        agent._record_metrics(group_result)
        calls = self._get_record_metrics_test_calls(failed_no_retries=1)
        put_metric_mock.assert_has_calls(calls, any_order=True)

    def _get_record_metrics_test_calls(
        self,
        success_with_retries=0,
        success_no_retries=0,
        failed_with_retries=0,
        failed_no_retries=0,
    ):
        """Expected put_metric call list; exactly one counter should be 1."""
        calls = [
            call("workers.test_trainer.run_success_with_retries", success_with_retries),
            call("workers.test_trainer.run_success_no_retries", success_no_retries),
            call("workers.test_trainer.run_failed_with_retries", failed_with_retries),
            call("workers.test_trainer.run_failed_no_retries", failed_no_retries),
        ]
        return calls

    def test_rendezvous(self):
        """After rendezvous on a single agent: group rank/world size are set,
        and each worker's global rank / world size follow the local layout."""
        hostname = _get_fq_hostname()
        spec = self._get_worker_spec(max_restarts=1, local_addr=hostname)
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()
        agent._rendezvous(worker_group)

        # single agent rdzv
        self.assertEqual(1, worker_group.group_world_size)
        self.assertEqual(0, worker_group.group_rank)

        self.assertEqual(hostname, worker_group.master_addr)
        self.assertTrue(worker_group.master_port > 0)

        rank_set = {w.global_rank for w in worker_group.workers}
        for w in worker_group.workers:
            # ids are assigned by _start_workers, not by rendezvous
            self.assertIsNone(w.id)
            local_world_size = spec.local_world_size
            group_world_size = worker_group.group_world_size
            group_rank = worker_group.group_rank

            self.assertEqual(local_world_size * group_world_size, w.world_size)
            self.assertEqual(
                local_world_size * group_rank + w.local_rank, w.global_rank
            )
            self.assertSetEqual(set(range(w.world_size)), rank_set)

    def test_rendezvous_default_master_addr(self):
        """With the FQ hostname as local_addr, master_addr resolves to it."""
        hostname = _get_fq_hostname()
        spec = self._get_worker_spec(max_restarts=1, local_addr=hostname)
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()
        agent._rendezvous(worker_group)

        self.assertEqual(_get_fq_hostname(), worker_group.master_addr)
        self.assertGreater(worker_group.master_port, 0)

    def test_rendezvous_master_addr_with_local_addr(self):
        """An explicit local_addr overrides the FQ hostname as master_addr."""
        spec_local_addr = "127.0.0.1"
        spec = self._get_worker_spec(max_restarts=1, local_addr=spec_local_addr)
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()
        agent._rendezvous(worker_group)

        self.assertNotEqual(_get_fq_hostname(), worker_group.master_addr)
        self.assertEqual(spec_local_addr, worker_group.master_addr)
        self.assertGreater(worker_group.master_port, 0)

    def test_initialize_workers(self):
        """Initialization leaves the group HEALTHY with ids == global ranks
        (the fake _start_workers assigns global rank as the worker id)."""
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()
        agent._initialize_workers(worker_group)

        self.assertEqual(WorkerState.HEALTHY, worker_group.state)
        for i in range(spec.local_world_size):
            worker = worker_group.workers[i]
            self.assertEqual(worker.id, worker.global_rank)

    def test_restart_workers(self):
        """Each restart stops then re-initializes workers, re-assigning the
        rendezvous-derived fields every time."""
        spec = self._get_worker_spec()
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()

        num_restarts = 3
        for _ in range(0, num_restarts):
            agent._restart_workers(worker_group)
            self.assertEqual(WorkerState.HEALTHY, worker_group.state)

            # test_rendezvous and test_initialize_workers
            # already validates the correctness of these fields
            # simply validate that they are not None
            # (e.g. that they get assigned)
            self.assertIsNotNone(worker_group.group_rank)
            self.assertIsNotNone(worker_group.group_world_size)
            for w in worker_group.workers:
                self.assertIsNotNone(w.id)
                self.assertIsNotNone(w.global_rank)
                self.assertIsNotNone(w.world_size)

        self.assertEqual(num_restarts, agent.start_workers_call_count)
        self.assertEqual(num_restarts, agent.stop_workers_call_count)

    @patch.object(
        TestAgent,
        "_monitor_workers",
        side_effect=[
            monres(WorkerState.HEALTHY),
            monres(WorkerState.HEALTHY),
            monres(WorkerState.SUCCEEDED),
        ],
    )
    @patch.object(TestAgent, "_record_worker_events")
    def test_run_happy_path(self, record_events_mock, mock_monitor_workers):
        # worker starts
        # is always healthy
        # then succeeds
        max_restarts = 10
        spec = self._get_worker_spec(max_restarts)
        agent = TestAgent(spec)

        agent.run()

        # no failure, no membership changes -> no retries
        self.assertEqual(max_restarts, agent._remaining_restarts)
        record_events_mock.assert_called_once()

    @patch.object(TestAgent, "_initialize_workers", side_effect=RuntimeError())
    def test_run_initialization_failure(self, mock_initialize_workers):
        """A failure during worker initialization propagates out of run()
        and leaves the group in INIT."""
        spec = self._get_worker_spec()
        agent = TestAgent(spec)
        worker_group = agent._worker_group

        with self.assertRaises(RuntimeError):
            agent.run()

        self.assertEqual(WorkerState.INIT, worker_group.state)

    def test_run_max_retries_exceeded(self):
        """FAILED/UNHEALTHY results are retried until the budget is exhausted,
        then the group ends FAILED."""
        for restartable_state in [
            monres(WorkerState.FAILED),
            monres(WorkerState.UNHEALTHY),
        ]:
            with patch.object(
                TestAgent, "_monitor_workers", return_value=restartable_state
            ) as mock_monitor_workers:
                spec = self._get_worker_spec(max_restarts=3, monitor_interval=0.1)
                agent = TestAgent(spec)
                worker_group = agent._worker_group

                agent.run()
                self.assertEqual(WorkerState.FAILED, worker_group.state)
                self.assertEqual(0, agent._remaining_restarts)
                # one monitor call for each retry + one to monitor the last retry
                self.assertEqual(spec.max_restarts + 1, mock_monitor_workers.call_count)

    @patch.object(
        TestAgent,
        "_monitor_workers",
        side_effect=[
            monres(WorkerState.HEALTHY),
            monres(WorkerState.HEALTHY),
            monres(WorkerState.HEALTHY),
            monres(WorkerState.SUCCEEDED),
        ],
    )
    @patch.object(RendezvousHandler, "num_nodes_waiting", side_effect=[1, 1, 0])
    @patch.object(TestAgent, "_record_worker_events")
    def test_run_membership_change(
        self, record_events_mock, mock_num_nodes_waiting, mock_monitor_workers
    ):
        """Waiting nodes trigger a restart (membership change) but the run
        still ends SUCCEEDED once the workers succeed."""
        spec = self._get_worker_spec(max_restarts=1, monitor_interval=0.1)
        agent = TestAgent(spec)
        worker_group = agent._worker_group

        agent.run()
        self.assertEqual(WorkerState.SUCCEEDED, worker_group.state)
        record_events_mock.assert_called_once()

    @patch.object(
        TestAgent, "_monitor_workers", return_value=monres(WorkerState.UNKNOWN)
    )
    def test_run_unknown_state(self, mock_monitor_workers):
        # when the state is unknown we exit immediately; no retries
        spec = self._get_worker_spec(max_restarts=100, monitor_interval=0.1)
        agent = TestAgent(spec)
        worker_group = agent._worker_group

        with self.assertRaises(Exception):
            agent.run()

        self.assertEqual(WorkerState.UNKNOWN, worker_group.state)
        self.assertEqual(1, mock_monitor_workers.call_count)
        self.assertEqual(spec.max_restarts, agent._remaining_restarts)

    def test_assign_worker_ranks(self):
        """Five agents (two roles) rendezvous through a shared HashStore; each
        worker must get consistent local/role/global ranks and world sizes."""
        role_infos = [
            _RoleInstanceInfo("parameter_server", 0, 4),
            _RoleInstanceInfo("trainer", 1, 1),
            _RoleInstanceInfo("trainer", 2, 2),
            _RoleInstanceInfo("trainer", 3, 3),
            _RoleInstanceInfo("parameter_server", 4, 5),
        ]
        store = dist.HashStore()

        def f(info) -> List[Worker]:
            # one simulated agent: builds its spec and assigns worker ranks
            i, role_info = info
            spec = self._get_worker_spec(
                max_restarts=3,
                monitor_interval=0.1,
                role=role_info.role,
                local_world_size=role_info.local_world_size,
            )
            agent = TestAgent(spec)
            workers = agent._assign_worker_ranks(
                store, role_info.rank, len(role_infos), spec
            )
            return [
                (
                    w.local_rank,
                    w.role_rank,
                    w.global_rank,
                    w.world_size,
                    w.role_world_size,
                )
                for w in workers
            ]

        # rank assignment is a collective; run all agents concurrently
        with ThreadPool(len(role_infos)) as pool:
            out = pool.map(f, enumerate(role_infos))

        self.assertListEqual(
            out,
            [
                [
                    (0, 0, 0, 15, 9),
                    (1, 1, 1, 15, 9),
                    (2, 2, 2, 15, 9),
                    (3, 3, 3, 15, 9),
                ],
                [
                    (0, 0, 4, 15, 6),
                ],
                [
                    (0, 1, 5, 15, 6),
                    (1, 2, 6, 15, 6),
                ],
                [
                    (0, 3, 7, 15, 6),
                    (1, 4, 8, 15, 6),
                    (2, 5, 9, 15, 6),
                ],
                [
                    (0, 4, 10, 15, 9),
                    (1, 5, 11, 15, 9),
                    (2, 6, 12, 15, 9),
                    (3, 7, 13, 15, 9),
                    (4, 8, 14, 15, 9),
                ],
            ],
        )

    def test_get_event(self):
        """get_event_succeeded() reports an AGENT-sourced SUCCEEDED event
        carrying the rendezvous backend and role metadata."""
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        event = agent.get_event_succeeded()
        self.assertEqual("AGENT", event.source)
        self.assertEqual("static", event.metadata["rdzv_backend"])
        self.assertEqual("SUCCEEDED", event.metadata["state"])
        self.assertEqual(spec.role, event.metadata["role"])

    def test_get_worker_status_event(self):
        """_construct_event for a worker reports consumed restarts
        (max_restarts - remaining) in the metadata."""
        spec = self._get_worker_spec(max_restarts=4)
        agent = TestAgent(spec)
        agent._remaining_restarts = spec.max_restarts - 2
        actual_event = agent._construct_event(
            state="SUCCEEDED",
            source="WORKER",
            worker=agent._worker_group.workers[0],
        )
        self.assertEqual("WORKER", actual_event.source)
        self.assertEqual("static", actual_event.metadata["rdzv_backend"])
        self.assertEqual("SUCCEEDED", actual_event.metadata["state"])
        self.assertEqual(spec.role, actual_event.metadata["role"])
        self.assertEqual(2, actual_event.metadata["agent_restarts"])

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    @patch.object(TestAgent, "_invoke_run")
    def test_agent_process_signal_exception(self, invoke_run, _):
        """A SignalException from the run loop re-raises and passes the
        signal value through to _shutdown."""
        spec = self._get_worker_spec(max_restarts=0)
        agent = TestAgent(spec)
        invoke_run.side_effect = SignalException(
            "signal exception", sigval=signal.SIGTERM
        )
        with patch.object(agent, "_shutdown") as shutdown_mock:
            with self.assertRaises(SignalException):
                agent.run()
            args, _ = shutdown_mock.call_args
            self.assertEqual(signal.SIGTERM, args[0])

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    @patch.object(TestAgent, "_invoke_run")
    def test_agent_process_handler_graceful_exception(self, invoke_run, _):
        """RendezvousGracefulExitError is swallowed: run() returns normally."""
        spec = self._get_worker_spec(max_restarts=0)
        agent = TestAgent(spec)
        invoke_run.side_effect = RendezvousGracefulExitError()
        with patch.object(agent, "_shutdown"):
            agent.run()
591
592
# Allow running this test module directly (delegates to torch's test runner).
if __name__ == "__main__":
    run_tests()
595