xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/rendezvous/api_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: r2p"]
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
import re
from typing import Any, cast, Dict, SupportsInt
from unittest import TestCase

from torch.distributed.elastic.rendezvous import (
    RendezvousHandler,
    RendezvousHandlerRegistry,
    RendezvousInfo,
    RendezvousParameters,
)
18
19
class RendezvousParametersTest(TestCase):
    """Tests for ``RendezvousParameters`` construction and its typed getters."""

    def setUp(self) -> None:
        # Baseline constructor arguments. Individual tests overwrite these
        # attributes before calling ``_create_params``.
        self._kwargs: Dict[str, Any] = {}
        self._min_nodes = 3
        self._max_nodes = 6
        self._backend = "dummy_backend"
        self._endpoint = "dummy_endpoint"
        self._run_id = "dummy_run_id"

    def _create_params(self) -> RendezvousParameters:
        """Construct ``RendezvousParameters`` from the current fixture state."""
        fixed_args: Dict[str, Any] = {
            "backend": self._backend,
            "endpoint": self._endpoint,
            "run_id": self._run_id,
            "min_nodes": self._min_nodes,
            "max_nodes": self._max_nodes,
        }
        return RendezvousParameters(**fixed_args, **self._kwargs)

    def _assert_node_limits_match(self, params: RendezvousParameters) -> None:
        """Check that ``params`` carries the fixture's min/max node counts."""
        self.assertEqual(params.min_nodes, self._min_nodes)
        self.assertEqual(params.max_nodes, self._max_nodes)

    def test_init_initializes_params(self) -> None:
        self._kwargs["dummy_param"] = "x"

        params = self._create_params()

        actual_and_expected = (
            (params.backend, self._backend),
            (params.endpoint, self._endpoint),
            (params.run_id, self._run_id),
            (params.min_nodes, self._min_nodes),
            (params.max_nodes, self._max_nodes),
        )
        for actual, expected in actual_and_expected:
            self.assertEqual(actual, expected)

        # Extra keyword arguments are retrievable through ``get``.
        self.assertEqual(params.get("dummy_param"), "x")

    def test_init_initializes_params_if_min_nodes_equals_to_1(self) -> None:
        self._min_nodes = 1

        self._assert_node_limits_match(self._create_params())

    def test_init_initializes_params_if_min_and_max_nodes_are_equal(self) -> None:
        self._max_nodes = 3

        self._assert_node_limits_match(self._create_params())

    def test_init_raises_error_if_backend_is_none_or_empty(self) -> None:
        expected_pattern = r"^The rendezvous backend name must be a non-empty string.$"

        for invalid_backend in (None, ""):
            with self.subTest(backend=invalid_backend):
                self._backend = invalid_backend  # type: ignore[assignment]

                with self.assertRaisesRegex(ValueError, expected_pattern):
                    self._create_params()

    def test_init_raises_error_if_min_nodes_is_less_than_1(self) -> None:
        for node_count in (0, -1, -5):
            with self.subTest(min_nodes=node_count):
                self._min_nodes = node_count

                expected_pattern = (
                    rf"^The minimum number of rendezvous nodes \({node_count}\) must be greater "
                    rf"than zero.$"
                )
                with self.assertRaisesRegex(ValueError, expected_pattern):
                    self._create_params()

    def test_init_raises_error_if_max_nodes_is_less_than_min_nodes(self) -> None:
        for node_count in (2, 1, -2):
            with self.subTest(max_nodes=node_count):
                self._max_nodes = node_count

                expected_pattern = (
                    rf"^The maximum number of rendezvous nodes \({node_count}\) must be greater "
                    "than or equal to the minimum number of rendezvous nodes "
                    rf"\({self._min_nodes}\).$"
                )
                with self.assertRaisesRegex(ValueError, expected_pattern):
                    self._create_params()

    def test_get_returns_none_if_key_does_not_exist(self) -> None:
        self.assertIsNone(self._create_params().get("dummy_param"))

    def test_get_returns_default_if_key_does_not_exist(self) -> None:
        fetched = self._create_params().get("dummy_param", default="x")

        self.assertEqual(fetched, "x")

    def test_get_as_bool_returns_none_if_key_does_not_exist(self) -> None:
        self.assertIsNone(self._create_params().get_as_bool("dummy_param"))

    def test_get_as_bool_returns_default_if_key_does_not_exist(self) -> None:
        fetched = self._create_params().get_as_bool("dummy_param", default=True)

        self.assertTrue(fetched)

    def test_get_as_bool_returns_true_if_value_represents_true(self) -> None:
        truthy_inputs = ("1", "True", "tRue", "T", "t", "yEs", "Y", 1, True)

        for raw in truthy_inputs:
            with self.subTest(value=raw):
                self._kwargs["dummy_param"] = raw

                params = self._create_params()

                self.assertTrue(params.get_as_bool("dummy_param"))

    def test_get_as_bool_returns_false_if_value_represents_false(self) -> None:
        falsy_inputs = ("0", "False", "faLse", "F", "f", "nO", "N", 0, False)

        for raw in falsy_inputs:
            with self.subTest(value=raw):
                self._kwargs["dummy_param"] = raw

                params = self._create_params()

                self.assertFalse(params.get_as_bool("dummy_param"))

    def test_get_as_bool_raises_error_if_value_is_invalid(self) -> None:
        expected_pattern = (
            r"^The rendezvous configuration option 'dummy_param' does not represent a "
            r"valid boolean value.$"
        )

        for raw in ("01", "Flse", "Ture", "g", "4", "_", "truefalse", 2, -1):
            with self.subTest(value=raw):
                self._kwargs["dummy_param"] = raw

                # Construction itself must succeed; only the typed lookup
                # is expected to fail.
                params = self._create_params()

                with self.assertRaisesRegex(ValueError, expected_pattern):
                    params.get_as_bool("dummy_param")

    def test_get_as_int_returns_none_if_key_does_not_exist(self) -> None:
        self.assertIsNone(self._create_params().get_as_int("dummy_param"))

    def test_get_as_int_returns_default_if_key_does_not_exist(self) -> None:
        fetched = self._create_params().get_as_int("dummy_param", default=5)

        self.assertEqual(fetched, 5)

    def test_get_as_int_returns_integer_if_value_represents_integer(self) -> None:
        integer_inputs = ("0", "-10", "5", "  4", "4  ", " 4 ", 0, -4, 3)

        for raw in integer_inputs:
            with self.subTest(value=raw):
                self._kwargs["dummy_param"] = raw

                params = self._create_params()

                self.assertEqual(
                    params.get_as_int("dummy_param"), int(cast(SupportsInt, raw))
                )

    def test_get_as_int_raises_error_if_value_is_invalid(self) -> None:
        expected_pattern = (
            r"^The rendezvous configuration option 'dummy_param' does not represent a "
            r"valid integer value.$"
        )

        for raw in ("a", "0a", "3b", "abc"):
            with self.subTest(value=raw):
                self._kwargs["dummy_param"] = raw

                # Construction itself must succeed; only the typed lookup
                # is expected to fail.
                params = self._create_params()

                with self.assertRaisesRegex(ValueError, expected_pattern):
                    params.get_as_int("dummy_param")
190
191
192class _DummyRendezvousHandler(RendezvousHandler):
193    def __init__(self, params: RendezvousParameters) -> None:
194        self.params = params
195
196    def get_backend(self) -> str:
197        return "dummy_backend"
198
199    def next_rendezvous(self) -> RendezvousInfo:
200        raise NotImplementedError
201
202    def is_closed(self) -> bool:
203        return False
204
205    def set_closed(self) -> None:
206        pass
207
208    def num_nodes_waiting(self) -> int:
209        return 0
210
211    def get_run_id(self) -> str:
212        return ""
213
214    def shutdown(self) -> bool:
215        return False
216
217
class RendezvousHandlerRegistryTest(TestCase):
    """Tests for ``RendezvousHandlerRegistry`` registration and handler creation."""

    def setUp(self) -> None:
        # Parameters requesting "dummy_backend", the name reported by
        # ``_DummyRendezvousHandler.get_backend``.
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint="dummy_endpoint",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
        )

        self._registry = RendezvousHandlerRegistry()

    @staticmethod
    def _create_handler(params: RendezvousParameters) -> RendezvousHandler:
        """Creator callback registered with the registry under test."""
        return _DummyRendezvousHandler(params)

    def test_register_registers_once_if_called_twice_with_same_creator(self) -> None:
        # Registering the same creator twice must not raise; the test has no
        # further assertion.
        self._registry.register("dummy_backend", self._create_handler)
        self._registry.register("dummy_backend", self._create_handler)

    def test_register_raises_error_if_called_twice_with_different_creators(
        self,
    ) -> None:
        self._registry.register("dummy_backend", self._create_handler)

        def other_create_handler(params: RendezvousParameters) -> RendezvousHandler:
            return _DummyRendezvousHandler(params)

        # The expected message embeds the string form of both creators. Those
        # strings are arbitrary text, so the whole message is escaped and
        # matched literally rather than being interpreted as regex syntax.
        expected_message = (
            "The rendezvous backend 'dummy_backend' cannot be registered with "
            f"'{other_create_handler}' as it is already registered with "
            f"'{self._create_handler}'."
        )

        with self.assertRaisesRegex(ValueError, rf"^{re.escape(expected_message)}$"):
            self._registry.register("dummy_backend", other_create_handler)

    def test_create_handler_returns_handler(self) -> None:
        self._registry.register("dummy_backend", self._create_handler)

        handler = self._registry.create_handler(self._params)

        self.assertIsInstance(handler, _DummyRendezvousHandler)

        # The creator must receive the exact parameters object passed in.
        self.assertIs(handler.params, self._params)

    def test_create_handler_raises_error_if_backend_is_not_registered(self) -> None:
        with self.assertRaisesRegex(
            ValueError,
            r"^The rendezvous backend 'dummy_backend' is not registered. Did you forget to call "
            r"`register`\?$",
        ):
            self._registry.create_handler(self._params)

    def test_create_handler_raises_error_if_backend_names_do_not_match(self) -> None:
        # The creator is registered and requested as "dummy_backend_2", while
        # the handler it builds reports "dummy_backend".
        self._registry.register("dummy_backend_2", self._create_handler)
        self._params.backend = "dummy_backend_2"

        # Only the call expected to raise lives inside the context manager.
        with self.assertRaisesRegex(
            RuntimeError,
            r"^The rendezvous backend 'dummy_backend' does not match the requested backend "
            r"'dummy_backend_2'.$",
        ):
            self._registry.create_handler(self._params)
280