#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import signal
import unittest
import uuid
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List
from unittest.mock import call, patch

import torch.distributed as dist
import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic.agent.server.api import (
    _get_fq_hostname,
    _RoleInstanceInfo,
    RunResult,
    SimpleElasticAgent,
    Worker,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from torch.distributed.elastic.multiprocessing import SignalException
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
from torch.distributed.elastic.rendezvous.api import RendezvousGracefulExitError
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import run_tests


def do_nothing():
    # Placeholder entrypoint fn for WorkerSpec; the fake agents never run it.
    pass


class WorkerStateTest(unittest.TestCase):
    """Tests for the WorkerState enum."""

    def test_is_running(self):
        # Only HEALTHY and UNHEALTHY count as "running"; every other state
        # (INIT, SUCCEEDED, FAILED, ...) must report not running.
        for state in WorkerState:
            if state == WorkerState.HEALTHY or state == WorkerState.UNHEALTHY:
                self.assertTrue(WorkerState.is_running(state))
            else:
                self.assertFalse(WorkerState.is_running(state))


class WorkerGroupTest(unittest.TestCase):
    """Tests for WorkerGroup construction invariants."""

    def test_worker_group_constructor(self):
        spec = WorkerSpec(
            role="test_trainer",
            local_world_size=4,
            fn=do_nothing,
            args=(),
            rdzv_handler=None,
            max_restarts=50,
            monitor_interval=0.1,
        )
        worker_group = WorkerGroup(spec)

        # A freshly constructed group starts in INIT state.
        self.assertEqual(WorkerState.INIT, worker_group.state)

        # One Worker per local_world_size slot.
        workers = worker_group.workers
        self.assertEqual(4, len(workers))

        # validate full, consecutive local ranks
        self.assertSetEqual(set(range(4)), {w.local_rank for w in workers})

        # global_rank, world_size are assigned after rdzv
        # id is assigned after starting worker (by the agent)
        # validate they still hold the unassigned sentinels
        # (-1 for ranks/world_size, None for id)
        for w in workers:
            self.assertEqual(-1, w.global_rank)
            self.assertEqual(-1, w.world_size)
            self.assertEqual(None, w.id)

        # rank and store are assigned after rdzv; validate that they are None
        self.assertIsNone(worker_group.group_rank)
        self.assertIsNone(worker_group.store)


class RoleInstanceInfoTest(unittest.TestCase):
    """Tests for _RoleInstanceInfo ordering, serialization and role boundaries."""

    def test_compare(self):
        # Same role: ordered by rank.
        agent_role1 = _RoleInstanceInfo("role", 1, 10)
        agent_role2 = _RoleInstanceInfo("role", 2, 10)
        self.assertEqual(1, _RoleInstanceInfo.compare(agent_role2, agent_role1))
        # Different roles: ordered lexicographically by role name,
        # regardless of rank.
        agent_role1 = _RoleInstanceInfo("role1", 1, 10)
        agent_role2 = _RoleInstanceInfo("role2", 2, 10)
        self.assertEqual(-1, _RoleInstanceInfo.compare(agent_role1, agent_role2))
        agent_role1 = _RoleInstanceInfo("role1", 1, 10)
        agent_role2 = _RoleInstanceInfo("role2", 1, 10)
        self.assertEqual(-1, _RoleInstanceInfo.compare(agent_role1, agent_role2))

    def test_serde(self):
        # Round-trip: serialize() -> deserialize() preserves all fields.
        agent_role = _RoleInstanceInfo("role", 1, 10)
        str_data = agent_role.serialize()
        actual_agent_role = _RoleInstanceInfo.deserialize(str_data)
        self.assertEqual(agent_role.role, actual_agent_role.role)
        self.assertEqual(agent_role.rank, actual_agent_role.rank)
        self.assertEqual(
            agent_role.local_world_size, actual_agent_role.local_world_size
        )

    def test_find_boundaries(self):
        # Boundaries are (start_idx, end_idx) inclusive over the input list.
        role_infos = [
            _RoleInstanceInfo("trainer", 1, 1),
            _RoleInstanceInfo("trainer", 2, 2),
            _RoleInstanceInfo("trainer", 3, 3),
            _RoleInstanceInfo("parameter_server", 4, 5),
            _RoleInstanceInfo("parameter_server", 0, 4),
        ]
        start_idx, end_idx = _RoleInstanceInfo.find_role_boundaries(
            role_infos, "trainer"
        )
        self.assertEqual(start_idx, 0)
        self.assertEqual(end_idx, 2)
class TestAgent(SimpleElasticAgent):
    """Fake SimpleElasticAgent: never launches real processes.

    _start_workers/_stop_workers only bookkeep call counts;
    _monitor_workers must be mocked by each test.
    """

    def __init__(self, spec):
        super().__init__(spec)
        self.stop_workers_call_count = 0
        self.start_workers_call_count = 0

    def _stop_workers(
        self, worker_group: WorkerGroup, is_restart: bool = False
    ) -> None:
        # workers are fake, nothing to stop; just clear the rdzv info
        worker_group.group_rank = None
        worker_group.group_world_size = None
        self.stop_workers_call_count += 1

    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        # create fake workers; make worker id equal to global rank
        ids = {}
        for worker in worker_group.workers:
            ids[worker.local_rank] = worker.global_rank
        self.start_workers_call_count += 1
        return ids

    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        raise NotImplementedError("mock this method")

    def _shutdown(self):
        pass


def monres(state: WorkerState):
    """Build a RunResult for ``state``.

    SUCCEEDED carries a dummy return value for rank 0; UNHEALTHY/FAILED
    carry a fabricated ProcessFailure; everything else is state-only.
    """
    if state == WorkerState.SUCCEEDED:
        return RunResult(state=state, return_values={0: 0}, failures={})
    elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
        pf = ProcessFailure(local_rank=0, pid=999, exitcode=1, error_file="<none>")
        return RunResult(state=state, return_values={}, failures={0: pf})
    else:
        return RunResult(state=state)


class SimpleElasticAgentTest(unittest.TestCase):
    """Tests for SimpleElasticAgent via the fake TestAgent subclass."""

    def _get_worker_spec(
        self,
        max_restarts=1,
        monitor_interval=0.1,
        role="test_trainer",
        local_world_size=8,
        local_addr=None,
    ):
        # Builds a single-node static-rendezvous WorkerSpec on a free
        # local port, with a fresh run_id per call.
        run_id = str(uuid.uuid4().int)
        port = get_free_port()
        if local_addr is None:
            endpoint = f"127.0.0.1:{port}"
        else:
            endpoint = f"{local_addr}:{port}"

        rdzv_params = RendezvousParameters(
            backend="static",
            endpoint=endpoint,
            run_id=run_id,
            min_nodes=1,
            max_nodes=1,
            rank=0,
        )
        rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
        spec = WorkerSpec(
            role=role,
            local_world_size=local_world_size,
            fn=do_nothing,
            args=(),
            rdzv_handler=rdzv_handler,
            max_restarts=max_restarts,
            monitor_interval=monitor_interval,
            local_addr=local_addr,
        )
        return spec

    def test_agent_constructor(self):
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()
        self.assertEqual(WorkerState.INIT, worker_group.state)
        self.assertEqual(spec.max_restarts, agent._remaining_restarts)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_flakiness_metric(self, put_metric_mock):
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        agent._record_flakiness_metric()
        put_metric_mock.assert_called_with("workers.test_trainer.flakiness", 0)
        # 3 of 10 restarts remaining -> 7/10 consumed -> 63% flakiness
        # (presumably (max - remaining - ...) scaled to percent — value
        # pinned by the implementation, not derived here)
        agent._worker_group.spec.max_restarts = 10
        agent._remaining_restarts = 3
        agent._record_flakiness_metric()
        put_metric_mock.assert_called_with("workers.test_trainer.flakiness", 63)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_flakiness_metric_zero_restarts(self, put_metric_mock):
        # max_restarts == 0 must not divide-by-zero; flakiness stays 0.
        spec = self._get_worker_spec(max_restarts=1)
        spec.max_restarts = 0
        agent = TestAgent(spec)
        agent._record_flakiness_metric()
        put_metric_mock.assert_called_with("workers.test_trainer.flakiness", 0)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_flakiness_metric_user_exception(self, put_metric_mock):
        # A user exception is reported as maximum (100%) flakiness.
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        agent._record_flakiness_metric(True)
        put_metric_mock.assert_called_with("workers.test_trainer.flakiness", 100)

    @patch.object(TestAgent, "_invoke_run")
    @patch.object(TestAgent, "_record_metrics")
    @patch.object(TestAgent, "_record_worker_events")
    @patch.object(TestAgent, "_shutdown")
    def test_invoke_run(
        self, shutdown_mock, record_events_mock, record_metrics_mock, invoke_run_mock
    ):
        # run() must invoke the run loop, record metrics/events, and shut down
        # exactly once each.
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        agent.run()
        invoke_run_mock.assert_called_once()
        record_metrics_mock.assert_called_once()
        record_events_mock.assert_called_once()
        shutdown_mock.assert_called_once()

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_metrics_success_no_retries(self, put_metric_mock):
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        group_result = RunResult({}, {})
        agent._record_metrics(group_result)
        calls = self._get_record_metrics_test_calls(success_no_retries=1)
        put_metric_mock.assert_has_calls(calls, any_order=True)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_metrics_success_with_retries(self, put_metric_mock):
        spec = self._get_worker_spec(max_restarts=10)
        agent = TestAgent(spec)
        # remaining < max signals that retries were consumed
        agent._remaining_restarts = 2
        group_result = RunResult({}, {})
        agent._record_metrics(group_result)
        calls = self._get_record_metrics_test_calls(success_with_retries=1)
        put_metric_mock.assert_has_calls(calls, any_order=True)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_metrics_failed_with_retries(self, put_metric_mock):
        spec = self._get_worker_spec(max_restarts=10)
        agent = TestAgent(spec)
        agent._remaining_restarts = 2
        group_result = RunResult(
            state=WorkerState.FAILED, return_values={}, failures={0: 0}
        )
        agent._record_metrics(group_result)
        calls = self._get_record_metrics_test_calls(failed_with_retries=1)
        put_metric_mock.assert_has_calls(calls, any_order=True)

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    def test_record_metrics_failed_no_retries(self, put_metric_mock):
        spec = self._get_worker_spec(max_restarts=10)
        agent = TestAgent(spec)
        group_result = RunResult(
            state=WorkerState.FAILED, return_values={}, failures={0: 0}
        )
        agent._record_metrics(group_result)
        calls = self._get_record_metrics_test_calls(failed_no_retries=1)
        put_metric_mock.assert_has_calls(calls, any_order=True)

    def _get_record_metrics_test_calls(
        self,
        success_with_retries=0,
        success_no_retries=0,
        failed_with_retries=0,
        failed_no_retries=0,
    ):
        # Expected put_metric calls: exactly one of the four counters is 1,
        # the rest 0 (all four are always emitted).
        calls = [
            call("workers.test_trainer.run_success_with_retries", success_with_retries),
            call("workers.test_trainer.run_success_no_retries", success_no_retries),
            call("workers.test_trainer.run_failed_with_retries", failed_with_retries),
            call("workers.test_trainer.run_failed_no_retries", failed_no_retries),
        ]
        return calls

    def test_rendezvous(self):
        hostname = _get_fq_hostname()
        spec = self._get_worker_spec(max_restarts=1, local_addr=hostname)
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()
        agent._rendezvous(worker_group)

        # single agent rdzv
        self.assertEqual(1, worker_group.group_world_size)
        self.assertEqual(0, worker_group.group_rank)

        self.assertEqual(hostname, worker_group.master_addr)
        self.assertTrue(worker_group.master_port > 0)

        # rendezvous assigns ranks/world sizes but NOT worker ids
        # (ids come from _start_workers); global ranks must be the full,
        # contiguous range and follow global = local_ws * group_rank + local.
        rank_set = {w.global_rank for w in worker_group.workers}
        for w in worker_group.workers:
            self.assertIsNone(w.id)
            local_world_size = spec.local_world_size
            group_world_size = worker_group.group_world_size
            group_rank = worker_group.group_rank

            self.assertEqual(local_world_size * group_world_size, w.world_size)
            self.assertEqual(
                local_world_size * group_rank + w.local_rank, w.global_rank
            )
            self.assertSetEqual(set(range(w.world_size)), rank_set)

    def test_rendezvous_default_master_addr(self):
        hostname = _get_fq_hostname()
        spec = self._get_worker_spec(max_restarts=1, local_addr=hostname)
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()
        agent._rendezvous(worker_group)

        self.assertEqual(_get_fq_hostname(), worker_group.master_addr)
        self.assertGreater(worker_group.master_port, 0)

    def test_rendezvous_master_addr_with_local_addr(self):
        # An explicit local_addr must win over the FQ hostname.
        spec_local_addr = "127.0.0.1"
        spec = self._get_worker_spec(max_restarts=1, local_addr=spec_local_addr)
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()
        agent._rendezvous(worker_group)

        self.assertNotEqual(_get_fq_hostname(), worker_group.master_addr)
        self.assertEqual(spec_local_addr, worker_group.master_addr)
        self.assertGreater(worker_group.master_port, 0)

    def test_initialize_workers(self):
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()
        agent._initialize_workers(worker_group)

        self.assertEqual(WorkerState.HEALTHY, worker_group.state)
        for i in range(spec.local_world_size):
            worker = worker_group.workers[i]
            # TestAgent._start_workers sets id == global_rank
            self.assertEqual(worker.id, worker.global_rank)

    def test_restart_workers(self):
        spec = self._get_worker_spec()
        agent = TestAgent(spec)
        worker_group = agent.get_worker_group()

        num_restarts = 3
        for _ in range(0, num_restarts):
            agent._restart_workers(worker_group)
            self.assertEqual(WorkerState.HEALTHY, worker_group.state)

            # test_rendezvous and test_initialize_workers
            # already validates the correctness of these fields
            # simply validate that they are not None
            # (e.g. that they get assigned)
            self.assertIsNotNone(worker_group.group_rank)
            self.assertIsNotNone(worker_group.group_world_size)
            for w in worker_group.workers:
                self.assertIsNotNone(w.id)
                self.assertIsNotNone(w.global_rank)
                self.assertIsNotNone(w.world_size)

        # every restart is one stop + one start
        self.assertEqual(num_restarts, agent.start_workers_call_count)
        self.assertEqual(num_restarts, agent.stop_workers_call_count)

    @patch.object(
        TestAgent,
        "_monitor_workers",
        side_effect=[
            monres(WorkerState.HEALTHY),
            monres(WorkerState.HEALTHY),
            monres(WorkerState.SUCCEEDED),
        ],
    )
    @patch.object(TestAgent, "_record_worker_events")
    def test_run_happy_path(self, record_events_mock, mock_monitor_workers):
        # worker starts
        # is always healthy
        # then succeeds
        max_restarts = 10
        spec = self._get_worker_spec(max_restarts)
        agent = TestAgent(spec)

        agent.run()

        # no failure, no membership changes -> no retries
        self.assertEqual(max_restarts, agent._remaining_restarts)
        record_events_mock.assert_called_once()

    @patch.object(TestAgent, "_initialize_workers", side_effect=RuntimeError())
    def test_run_initialization_failure(self, mock_initialize_workers):
        # An init failure propagates out of run() and leaves the group in INIT.
        spec = self._get_worker_spec()
        agent = TestAgent(spec)
        worker_group = agent._worker_group

        with self.assertRaises(RuntimeError):
            agent.run()

        self.assertEqual(WorkerState.INIT, worker_group.state)

    def test_run_max_retries_exceeded(self):
        # Both FAILED and UNHEALTHY are restartable; exhausting max_restarts
        # must end the run in FAILED with zero restarts remaining.
        for restartable_state in [
            monres(WorkerState.FAILED),
            monres(WorkerState.UNHEALTHY),
        ]:
            with patch.object(
                TestAgent, "_monitor_workers", return_value=restartable_state
            ) as mock_monitor_workers:
                spec = self._get_worker_spec(max_restarts=3, monitor_interval=0.1)
                agent = TestAgent(spec)
                worker_group = agent._worker_group

                agent.run()
                self.assertEqual(WorkerState.FAILED, worker_group.state)
                self.assertEqual(0, agent._remaining_restarts)
                # one monitor call for each retry + one to monitor the last retry
                self.assertEqual(spec.max_restarts + 1, mock_monitor_workers.call_count)

    @patch.object(
        TestAgent,
        "_monitor_workers",
        side_effect=[
            monres(WorkerState.HEALTHY),
            monres(WorkerState.HEALTHY),
            monres(WorkerState.HEALTHY),
            monres(WorkerState.SUCCEEDED),
        ],
    )
    @patch.object(RendezvousHandler, "num_nodes_waiting", side_effect=[1, 1, 0])
    @patch.object(TestAgent, "_record_worker_events")
    def test_run_membership_change(
        self, record_events_mock, mock_num_nodes_waiting, mock_monitor_workers
    ):
        # Waiting nodes (membership change) trigger a restart while healthy;
        # the run still ends SUCCEEDED.
        spec = self._get_worker_spec(max_restarts=1, monitor_interval=0.1)
        agent = TestAgent(spec)
        worker_group = agent._worker_group

        agent.run()
        self.assertEqual(WorkerState.SUCCEEDED, worker_group.state)
        record_events_mock.assert_called_once()

    @patch.object(
        TestAgent, "_monitor_workers", return_value=monres(WorkerState.UNKNOWN)
    )
    def test_run_unknown_state(self, mock_monitor_workers):
        # when the state is unknown we exit immediately; no retries
        spec = self._get_worker_spec(max_restarts=100, monitor_interval=0.1)
        agent = TestAgent(spec)
        worker_group = agent._worker_group

        with self.assertRaises(Exception):
            agent.run()

        self.assertEqual(WorkerState.UNKNOWN, worker_group.state)
        self.assertEqual(1, mock_monitor_workers.call_count)
        self.assertEqual(spec.max_restarts, agent._remaining_restarts)

    def test_assign_worker_ranks(self):
        # Five agents (two roles) rendezvous over a shared HashStore; each
        # thread simulates one agent and returns its workers' rank tuples.
        role_infos = [
            _RoleInstanceInfo("parameter_server", 0, 4),
            _RoleInstanceInfo("trainer", 1, 1),
            _RoleInstanceInfo("trainer", 2, 2),
            _RoleInstanceInfo("trainer", 3, 3),
            _RoleInstanceInfo("parameter_server", 4, 5),
        ]
        store = dist.HashStore()

        def f(info) -> List[Worker]:
            i, role_info = info
            spec = self._get_worker_spec(
                max_restarts=3,
                monitor_interval=0.1,
                role=role_info.role,
                local_world_size=role_info.local_world_size,
            )
            agent = TestAgent(spec)
            workers = agent._assign_worker_ranks(
                store, role_info.rank, len(role_infos), spec
            )
            return [
                (
                    w.local_rank,
                    w.role_rank,
                    w.global_rank,
                    w.world_size,
                    w.role_world_size,
                )
                for w in workers
            ]

        with ThreadPool(len(role_infos)) as pool:
            out = pool.map(f, enumerate(role_infos))

        # world_size = 1+2+3+4+5 = 15; role world sizes: ps=9, trainer=6.
        # Tuples are (local_rank, role_rank, global_rank, world_size,
        # role_world_size) per agent, in role_infos order.
        self.assertListEqual(
            out,
            [
                [
                    (0, 0, 0, 15, 9),
                    (1, 1, 1, 15, 9),
                    (2, 2, 2, 15, 9),
                    (3, 3, 3, 15, 9),
                ],
                [
                    (0, 0, 4, 15, 6),
                ],
                [
                    (0, 1, 5, 15, 6),
                    (1, 2, 6, 15, 6),
                ],
                [
                    (0, 3, 7, 15, 6),
                    (1, 4, 8, 15, 6),
                    (2, 5, 9, 15, 6),
                ],
                [
                    (0, 4, 10, 15, 9),
                    (1, 5, 11, 15, 9),
                    (2, 6, 12, 15, 9),
                    (3, 7, 13, 15, 9),
                    (4, 8, 14, 15, 9),
                ],
            ],
        )

    def test_get_event(self):
        spec = self._get_worker_spec(max_restarts=1)
        agent = TestAgent(spec)
        event = agent.get_event_succeeded()
        self.assertEqual("AGENT", event.source)
        self.assertEqual("static", event.metadata["rdzv_backend"])
        self.assertEqual("SUCCEEDED", event.metadata["state"])
        self.assertEqual(spec.role, event.metadata["role"])

    def test_get_worker_status_event(self):
        spec = self._get_worker_spec(max_restarts=4)
        agent = TestAgent(spec)
        # agent_restarts = max_restarts - remaining = 2
        agent._remaining_restarts = spec.max_restarts - 2
        actual_event = agent._construct_event(
            state="SUCCEEDED",
            source="WORKER",
            worker=agent._worker_group.workers[0],
        )
        self.assertEqual("WORKER", actual_event.source)
        self.assertEqual("static", actual_event.metadata["rdzv_backend"])
        self.assertEqual("SUCCEEDED", actual_event.metadata["state"])
        self.assertEqual(spec.role, actual_event.metadata["role"])
        self.assertEqual(2, actual_event.metadata["agent_restarts"])

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    @patch.object(TestAgent, "_invoke_run")
    def test_agent_process_signal_exception(self, invoke_run, _):
        # A SignalException from the run loop propagates, and _shutdown
        # receives the originating signal number.
        spec = self._get_worker_spec(max_restarts=0)
        agent = TestAgent(spec)
        invoke_run.side_effect = SignalException(
            "signal exception", sigval=signal.SIGTERM
        )
        with patch.object(agent, "_shutdown") as shutdown_mock:
            with self.assertRaises(SignalException):
                agent.run()
            args, _ = shutdown_mock.call_args
            self.assertEqual(signal.SIGTERM, args[0])

    @patch("torch.distributed.elastic.agent.server.api.put_metric")
    @patch.object(TestAgent, "_invoke_run")
    def test_agent_process_handler_graceful_exception(self, invoke_run, _):
        # RendezvousGracefulExitError is swallowed by run(): no exception
        # reaches the caller.
        spec = self._get_worker_spec(max_restarts=0)
        agent = TestAgent(spec)
        invoke_run.side_effect = RendezvousGracefulExitError()
        with patch.object(agent, "_shutdown"):
            agent.run()


if __name__ == "__main__":
    run_tests()