1# Owner(s): ["oncall: r2p"] 2 3# Copyright (c) Facebook, Inc. and its affiliates. 4# All rights reserved. 5# 6# This source code is licensed under the BSD-style license found in the 7# LICENSE file in the root directory of this source tree. 8 9from typing import Any, cast, Dict, SupportsInt 10from unittest import TestCase 11 12from torch.distributed.elastic.rendezvous import ( 13 RendezvousHandler, 14 RendezvousHandlerRegistry, 15 RendezvousInfo, 16 RendezvousParameters, 17) 18 19 20class RendezvousParametersTest(TestCase): 21 def setUp(self) -> None: 22 self._backend = "dummy_backend" 23 self._endpoint = "dummy_endpoint" 24 self._run_id = "dummy_run_id" 25 self._min_nodes = 3 26 self._max_nodes = 6 27 self._kwargs: Dict[str, Any] = {} 28 29 def _create_params(self) -> RendezvousParameters: 30 return RendezvousParameters( 31 backend=self._backend, 32 endpoint=self._endpoint, 33 run_id=self._run_id, 34 min_nodes=self._min_nodes, 35 max_nodes=self._max_nodes, 36 **self._kwargs, 37 ) 38 39 def test_init_initializes_params(self) -> None: 40 self._kwargs["dummy_param"] = "x" 41 42 params = self._create_params() 43 44 self.assertEqual(params.backend, self._backend) 45 self.assertEqual(params.endpoint, self._endpoint) 46 self.assertEqual(params.run_id, self._run_id) 47 self.assertEqual(params.min_nodes, self._min_nodes) 48 self.assertEqual(params.max_nodes, self._max_nodes) 49 50 self.assertEqual(params.get("dummy_param"), "x") 51 52 def test_init_initializes_params_if_min_nodes_equals_to_1(self) -> None: 53 self._min_nodes = 1 54 55 params = self._create_params() 56 57 self.assertEqual(params.min_nodes, self._min_nodes) 58 self.assertEqual(params.max_nodes, self._max_nodes) 59 60 def test_init_initializes_params_if_min_and_max_nodes_are_equal(self) -> None: 61 self._max_nodes = 3 62 63 params = self._create_params() 64 65 self.assertEqual(params.min_nodes, self._min_nodes) 66 self.assertEqual(params.max_nodes, self._max_nodes) 67 68 def test_init_raises_error_if_backend_is_none_or_empty(self) -> None: 69 for backend in [None, ""]: 70 with self.subTest(backend=backend): 71 self._backend = backend # type: ignore[assignment] 72 73 with self.assertRaisesRegex( 74 ValueError, 75 r"^The rendezvous backend name must be a non-empty string.$", 76 ): 77 self._create_params() 78 79 def test_init_raises_error_if_min_nodes_is_less_than_1(self) -> None: 80 for min_nodes in [0, -1, -5]: 81 with self.subTest(min_nodes=min_nodes): 82 self._min_nodes = min_nodes 83 84 with self.assertRaisesRegex( 85 ValueError, 86 rf"^The minimum number of rendezvous nodes \({min_nodes}\) must be greater " 87 rf"than zero.$", 88 ): 89 self._create_params() 90 91 def test_init_raises_error_if_max_nodes_is_less_than_min_nodes(self) -> None: 92 for max_nodes in [2, 1, -2]: 93 with self.subTest(max_nodes=max_nodes): 94 self._max_nodes = max_nodes 95 96 with self.assertRaisesRegex( 97 ValueError, 98 rf"^The maximum number of rendezvous nodes \({max_nodes}\) must be greater " 99 "than or equal to the minimum number of rendezvous nodes " 100 rf"\({self._min_nodes}\).$", 101 ): 102 self._create_params() 103 104 def test_get_returns_none_if_key_does_not_exist(self) -> None: 105 params = self._create_params() 106 107 self.assertIsNone(params.get("dummy_param")) 108 109 def test_get_returns_default_if_key_does_not_exist(self) -> None: 110 params = self._create_params() 111 112 self.assertEqual(params.get("dummy_param", default="x"), "x") 113 114 def test_get_as_bool_returns_none_if_key_does_not_exist(self) -> None: 115 params = self._create_params() 116 117 self.assertIsNone(params.get_as_bool("dummy_param")) 118 119 def test_get_as_bool_returns_default_if_key_does_not_exist(self) -> None: 120 params = self._create_params() 121 122 self.assertTrue(params.get_as_bool("dummy_param", default=True)) 123 124 def test_get_as_bool_returns_true_if_value_represents_true(self) -> None: 125 for value in ["1", "True", "tRue", "T", "t", "yEs", "Y", 1, True]: 126 with self.subTest(value=value): 127 self._kwargs["dummy_param"] = value 128 129 params = self._create_params() 130 131 self.assertTrue(params.get_as_bool("dummy_param")) 132 133 def test_get_as_bool_returns_false_if_value_represents_false(self) -> None: 134 for value in ["0", "False", "faLse", "F", "f", "nO", "N", 0, False]: 135 with self.subTest(value=value): 136 self._kwargs["dummy_param"] = value 137 138 params = self._create_params() 139 140 self.assertFalse(params.get_as_bool("dummy_param")) 141 142 def test_get_as_bool_raises_error_if_value_is_invalid(self) -> None: 143 for value in ["01", "Flse", "Ture", "g", "4", "_", "truefalse", 2, -1]: 144 with self.subTest(value=value): 145 self._kwargs["dummy_param"] = value 146 147 params = self._create_params() 148 149 with self.assertRaisesRegex( 150 ValueError, 151 r"^The rendezvous configuration option 'dummy_param' does not represent a " 152 r"valid boolean value.$", 153 ): 154 params.get_as_bool("dummy_param") 155 156 def test_get_as_int_returns_none_if_key_does_not_exist(self) -> None: 157 params = self._create_params() 158 159 self.assertIsNone(params.get_as_int("dummy_param")) 160 161 def test_get_as_int_returns_default_if_key_does_not_exist(self) -> None: 162 params = self._create_params() 163 164 self.assertEqual(params.get_as_int("dummy_param", default=5), 5) 165 166 def test_get_as_int_returns_integer_if_value_represents_integer(self) -> None: 167 for value in ["0", "-10", "5", " 4", "4 ", " 4 ", 0, -4, 3]: 168 with self.subTest(value=value): 169 self._kwargs["dummy_param"] = value 170 171 params = self._create_params() 172 173 self.assertEqual( 174 params.get_as_int("dummy_param"), int(cast(SupportsInt, value)) 175 ) 176 177 def test_get_as_int_raises_error_if_value_is_invalid(self) -> None: 178 for value in ["a", "0a", "3b", "abc"]: 179 with self.subTest(value=value): 180 self._kwargs["dummy_param"] = value 181 182 params = self._create_params() 183 184 with self.assertRaisesRegex( 185 ValueError, 186 r"^The rendezvous configuration option 'dummy_param' does not represent a " 187 r"valid integer value.$", 188 ): 189 params.get_as_int("dummy_param") 190 191 192class _DummyRendezvousHandler(RendezvousHandler): 193 def __init__(self, params: RendezvousParameters) -> None: 194 self.params = params 195 196 def get_backend(self) -> str: 197 return "dummy_backend" 198 199 def next_rendezvous(self) -> RendezvousInfo: 200 raise NotImplementedError 201 202 def is_closed(self) -> bool: 203 return False 204 205 def set_closed(self) -> None: 206 pass 207 208 def num_nodes_waiting(self) -> int: 209 return 0 210 211 def get_run_id(self) -> str: 212 return "" 213 214 def shutdown(self) -> bool: 215 return False 216 217 218class RendezvousHandlerRegistryTest(TestCase): 219 def setUp(self) -> None: 220 self._params = RendezvousParameters( 221 backend="dummy_backend", 222 endpoint="dummy_endpoint", 223 run_id="dummy_run_id", 224 min_nodes=1, 225 max_nodes=1, 226 ) 227 228 self._registry = RendezvousHandlerRegistry() 229 230 @staticmethod 231 def _create_handler(params: RendezvousParameters) -> RendezvousHandler: 232 return _DummyRendezvousHandler(params) 233 234 def test_register_registers_once_if_called_twice_with_same_creator(self) -> None: 235 self._registry.register("dummy_backend", self._create_handler) 236 self._registry.register("dummy_backend", self._create_handler) 237 238 def test_register_raises_error_if_called_twice_with_different_creators( 239 self, 240 ) -> None: 241 self._registry.register("dummy_backend", self._create_handler) 242 243 other_create_handler = lambda p: _DummyRendezvousHandler(p) # noqa: E731 244 245 with self.assertRaisesRegex( 246 ValueError, 247 r"^The rendezvous backend 'dummy_backend' cannot be registered with " 248 rf"'{other_create_handler}' as it is already registered with '{self._create_handler}'.$", 249 ): 250 self._registry.register("dummy_backend", other_create_handler) 251 252 def test_create_handler_returns_handler(self) -> None: 253 self._registry.register("dummy_backend", self._create_handler) 254 255 handler = self._registry.create_handler(self._params) 256 257 self.assertIsInstance(handler, _DummyRendezvousHandler) 258 259 self.assertIs(handler.params, self._params) 260 261 def test_create_handler_raises_error_if_backend_is_not_registered(self) -> None: 262 with self.assertRaisesRegex( 263 ValueError, 264 r"^The rendezvous backend 'dummy_backend' is not registered. Did you forget to call " 265 r"`register`\?$", 266 ): 267 self._registry.create_handler(self._params) 268 269 def test_create_handler_raises_error_if_backend_names_do_not_match(self) -> None: 270 self._registry.register("dummy_backend_2", self._create_handler) 271 272 with self.assertRaisesRegex( 273 RuntimeError, 274 r"^The rendezvous backend 'dummy_backend' does not match the requested backend " 275 r"'dummy_backend_2'.$", 276 ): 277 self._params.backend = "dummy_backend_2" 278 279 self._registry.create_handler(self._params) 280