xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/multi_worker_util_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for multi_worker_util."""
16
17from tensorflow.core.protobuf import cluster_pb2
18from tensorflow.python.distribute import multi_worker_util
19from tensorflow.python.eager import test
20from tensorflow.python.training import server_lib
21
22
class NormalizeClusterSpecTest(test.TestCase):
  """Tests for `multi_worker_util.normalize_cluster_spec`."""

  def assert_same_cluster(self, lhs, rhs):
    """Asserts that `lhs` and `rhs` describe the same cluster.

    Both sides are wrapped in `server_lib.ClusterSpec` and compared through
    `as_dict()`, so either side may be anything `ClusterSpec` accepts (the
    tests below pass dicts, `ClusterDef` protos and `ClusterSpec` objects).
    """
    self.assertEqual(
        server_lib.ClusterSpec(lhs).as_dict(),
        server_lib.ClusterSpec(rhs).as_dict())

  def _assert_normalized(self, cluster_spec):
    """Normalizes `cluster_spec`, then checks the result's type and content."""
    normalized = multi_worker_util.normalize_cluster_spec(cluster_spec)
    # `assert_same_cluster` re-wraps both sides in a `ClusterSpec`, so on its
    # own it would still pass if the input came back unnormalized (e.g. as
    # the raw dict). Check the result type explicitly as well.
    self.assertIsInstance(normalized, server_lib.ClusterSpec)
    self.assert_same_cluster(cluster_spec, normalized)

  def testDictAsInput(self):
    cluster_spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    self._assert_normalized(cluster_spec)

  def testClusterDefAsInput(self):
    cluster_def = cluster_pb2.ClusterDef()
    job = cluster_def.job.add()
    job.name = "chief"
    job.tasks[0] = "127.0.0.1:1234"

    job = cluster_def.job.add()
    job.name = "worker"
    job.tasks[0] = "127.0.0.1:8964"
    job.tasks[1] = "127.0.0.1:2333"

    job = cluster_def.job.add()
    job.name = "ps"
    job.tasks[0] = "127.0.0.1:1926"
    job.tasks[1] = "127.0.0.1:3141"

    self._assert_normalized(cluster_def)

  def testClusterSpecAsInput(self):
    cluster_spec = server_lib.ClusterSpec({
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    })
    self._assert_normalized(cluster_spec)

  def testUnexpectedInput(self):
    # A bare list of addresses (no job names) must be rejected.
    cluster_spec = ["127.0.0.1:8964", "127.0.0.1:2333"]

    with self.assertRaisesRegex(
        ValueError,
        "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
        "`tf.train.ClusterDef` object"):
      multi_worker_util.normalize_cluster_spec(cluster_spec)
75
76
class IsChiefTest(test.TestCase):
  """Tests for `multi_worker_util.is_chief`."""

  def testClusterWithChief(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    # With an explicit "chief" job, chief:0 is the chief and worker:0 is not.
    for task_type, task_id, expect_chief in (("chief", 0, True),
                                             ("worker", 0, False)):
      check = self.assertTrue if expect_chief else self.assertFalse
      check(multi_worker_util.is_chief(jobs, task_type, task_id))

  def testClusterWithoutChief(self):
    jobs = {"worker": ["127.0.0.1:8964", "127.0.0.1:2333"]}
    # Without a "chief" job, only worker:0 is reported as the chief.
    for task_id, expect_chief in ((0, True), (1, False)):
      check = self.assertTrue if expect_chief else self.assertFalse
      check(multi_worker_util.is_chief(jobs, "worker", task_id))

    # Asking about a missing job, or an out-of-range task id, is an error.
    for task_type, task_id, message in (
        ("chief", 0, "`task_type` 'chief' not found in cluster_spec."),
        ("worker", 2, "The `task_id` 2 exceeds the maximum id of worker."),
    ):
      with self.assertRaisesRegex(ValueError, message):
        multi_worker_util.is_chief(jobs, task_type, task_id)

  def testEvaluatorIsChief(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "evaluator": ["127.0.0.1:2019"]
    }
    # The evaluator is reported as chief even when a "chief" job exists.
    evaluator_is_chief = multi_worker_util.is_chief(jobs, "evaluator", 0)
    self.assertTrue(evaluator_is_chief)
108
109
class NumWorkersTest(test.TestCase):
  """Tests for `multi_worker_util.worker_count`."""

  def testCountWorker(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    # The chief and the workers are counted together (1 + 2); "ps" tasks are
    # not, and the result is the same whichever of the two types asks.
    for task_type in ("chief", "worker"):
      count = multi_worker_util.worker_count(jobs, task_type=task_type)
      self.assertEqual(count, 3)

  def testCountEvaluator(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "evaluator": ["127.0.0.1:7566"]
    }
    # An evaluator only counts itself.
    count = multi_worker_util.worker_count(jobs, task_type="evaluator")
    self.assertEqual(count, 1)

  def testTaskTypeNotFound(self):
    with self.assertRaisesRegex(
        ValueError, "`task_type` 'worker' not found in cluster_spec."):
      multi_worker_util.worker_count({}, task_type="worker")

  def testCountPs(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    # Parameter servers are not expected to ask for a worker count.
    with self.assertRaisesRegex(ValueError, "Unexpected `task_type` 'ps'"):
      multi_worker_util.worker_count(jobs, task_type="ps")
146
147
class IdInClusterTest(test.TestCase):
  """Tests for `multi_worker_util.id_in_cluster`."""

  def testChiefId(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    chief_id = multi_worker_util.id_in_cluster(jobs, "chief", 0)
    self.assertEqual(chief_id, 0)

  def testWorkerId(self):
    # With a chief present, worker:1 is expected to get id 2 ...
    with_chief = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    worker_id = multi_worker_util.id_in_cluster(with_chief, "worker", 1)
    self.assertEqual(worker_id, 2)

    # ... and id 1 when there is no chief.
    without_chief = {
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    worker_id = multi_worker_util.id_in_cluster(without_chief, "worker", 1)
    self.assertEqual(worker_id, 1)

  def testEvaluatorId(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "evaluator": ["127.0.0.1:7566"]
    }
    evaluator_id = multi_worker_util.id_in_cluster(jobs, "evaluator", 0)
    self.assertEqual(evaluator_id, 0)

  def testPsId(self):
    jobs = {"chief": ["127.0.0.1:1234"], "ps": ["127.0.0.1:7566"]}
    expected_message = "There is no id for task_type 'ps'"
    with self.assertRaisesRegex(ValueError, expected_message):
      multi_worker_util.id_in_cluster(jobs, "ps", 0)

  def testMultipleChiefs(self):
    jobs = {
        "chief": ["127.0.0.1:8258", "127.0.0.1:7566"],
    }
    expected_message = "There must be at most one 'chief' job."
    with self.assertRaisesRegex(ValueError, expected_message):
      multi_worker_util.id_in_cluster(jobs, "chief", 0)
197
198
class CollectiveLeaderTest(test.TestCase):
  """Tests for `multi_worker_util.collective_leader`."""

  def testChiefAsLeader(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    # When a chief exists, it leads collectives for the workers.
    leader = multi_worker_util.collective_leader(jobs, "worker", 0)
    self.assertEqual(leader, "/job:chief/replica:0/task:0")

  def testWorkerAsLeader(self):
    jobs = {
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    # Without a chief, worker:0 leads, even when worker:1 asks.
    leader = multi_worker_util.collective_leader(jobs, "worker", 1)
    self.assertEqual(leader, "/job:worker/replica:0/task:0")

  def testLeaderForEvaluator(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
        "evaluator": ["127.0.0.1:2019"]
    }
    # The evaluator gets an empty leader string.
    leader = multi_worker_util.collective_leader(jobs, "evaluator", 0)
    self.assertEqual(leader, "")

  def testLocalLeader(self):
    # An empty cluster spec with no task type also yields an empty leader.
    leader = multi_worker_util.collective_leader({}, None, 0)
    self.assertEqual(leader, "")
234
235
# Most of the validation logic is already exercised by the tests above; the
# cases below cover the remaining paths.
class ClusterSpecValidationTest(test.TestCase):
  """Direct tests for `multi_worker_util._validate_cluster_spec`."""

  def testEvaluatorNotInCluster(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    # None of these calls may raise. That includes "evaluator", which has no
    # entry in `jobs`.
    for task_type in ("chief", "worker", "ps", "evaluator"):
      multi_worker_util._validate_cluster_spec(jobs, task_type, 0)

  def testWorkerNotInCluster(self):
    jobs = {
        "chief": ["127.0.0.1:1234"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"]
    }
    # A missing "evaluator" job is tolerated ...
    multi_worker_util._validate_cluster_spec(jobs, "evaluator", 0)
    # ... but a missing "worker" job is rejected.
    expected_message = "`task_type` 'worker' not found in cluster_spec."
    with self.assertRaisesRegex(ValueError, expected_message):
      multi_worker_util._validate_cluster_spec(jobs, "worker", 0)
259
260
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner (imported above as `test`).
  test.main()
263