xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/multi_worker_test_base_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for multi-process clusters."""
16
17from tensorflow.python.distribute import multi_process_runner
18from tensorflow.python.distribute import multi_worker_test_base
19from tensorflow.python.eager import context
20from tensorflow.python.eager import remote
21from tensorflow.python.eager import test
22
23
class MultiProcessClusterTest(test.TestCase):
  """Tests for the cluster made by `create_multi_process_cluster`.

  Each test runs against a cluster created in `setUp` with two workers, one
  ps task and a chief, using the "grpc" RPC layer. Liveness of individual
  tasks is observed through `context.check_alive`.
  """

  # Device names of every task in the cluster built by `setUp`, in the order
  # in which the tests check them.
  _ALL_TASKS = (
      "/job:worker/replica:0/task:0",
      "/job:worker/replica:0/task:1",
      "/job:ps/replica:0/task:0",
      "/job:chief/replica:0/task:0",
  )

  def setUp(self):
    """Creates the cluster and connects the eager context to it."""
    super().setUp()
    cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=2, num_ps=1, has_chief=True, rpc_layer="grpc")
    self._cluster = cluster
    spec = cluster.cluster_resolver.cluster_spec()
    remote.connect_to_cluster(spec, protocol="grpc")
    context.ensure_initialized()

  def testClusterIsAlive(self):
    """Every task is reported alive right after setup."""
    for task in self._ALL_TASKS:
      self.assertTrue(context.check_alive(task))

  def testKillAndStartTask(self):
    """A task can be killed and then started again, but only in that order."""
    first_worker = "/job:worker/replica:0/task:0"
    cluster = self._cluster

    self.assertTrue(context.check_alive(first_worker))

    # Starting a task that has not been killed yet is rejected.
    self.assertRaises(ValueError, cluster.start_task, "worker", 0)

    cluster.kill_task("worker", 0)
    self.assertFalse(context.check_alive(first_worker))

    # Killing a task that is already dead is rejected as well.
    self.assertRaises(ValueError, cluster.kill_task, "worker", 0)

    cluster.start_task("worker", 0)

    # The server def has to be refreshed here; without this call the
    # check_alive below would return False. Sleeping for 2 seconds instead
    # also works.
    eager_context = context.context()
    eager_context.update_server_def(context.get_server_def())

    self.assertTrue(context.check_alive(first_worker))

  def testStop(self):
    """After stopping the cluster no task is reported alive."""
    self._cluster.stop()
    for task in self._ALL_TASKS:
      self.assertFalse(context.check_alive(task))

  def testClusterResolverProperty(self):
    """The resolver's cluster spec lists the expected number of tasks per job."""
    jobs = self._cluster.cluster_resolver.cluster_spec().as_dict()

    expected_sizes = (("worker", 2), ("ps", 1), ("chief", 1))
    for job_name, expected_size in expected_sizes:
      self.assertEqual(len(jobs[job_name]), expected_size)
75
76
if __name__ == "__main__":
  # Hand control to multi_process_runner's test entry point when this module
  # is executed as a script.
  multi_process_runner.test_main()
79