# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for multi-process clusters."""

from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.eager import test


class MultiProcessClusterTest(test.TestCase):
  """Checks liveness, kill/restart and shutdown of a multi-process cluster.

  Every test starts from a fresh cluster with one chief, two workers and one
  parameter server, and the eager context is connected to it over gRPC.
  """

  # Device names of every task created in setUp, in the order the liveness
  # checks below visit them.
  _ALL_TASKS = (
      "/job:worker/replica:0/task:0",
      "/job:worker/replica:0/task:1",
      "/job:ps/replica:0/task:0",
      "/job:chief/replica:0/task:0",
  )
  # Device name of worker task 0, the task that is killed and restarted.
  _FIRST_WORKER = _ALL_TASKS[0]

  def setUp(self):
    super().setUp()
    cluster = multi_worker_test_base.create_multi_process_cluster(
        num_workers=2, num_ps=1, has_chief=True, rpc_layer="grpc")
    self._cluster = cluster
    spec = cluster.cluster_resolver.cluster_spec()
    remote.connect_to_cluster(spec, protocol="grpc")
    context.ensure_initialized()

  def testClusterIsAlive(self):
    for device in self._ALL_TASKS:
      self.assertTrue(context.check_alive(device))

  def testKillAndStartTask(self):
    job, index = "worker", 0
    device = self._FIRST_WORKER
    self.assertTrue(context.check_alive(device))

    # start_task is rejected for a task that has not been killed first.
    with self.assertRaises(ValueError):
      self._cluster.start_task(job, index)

    self._cluster.kill_task(job, index)
    self.assertFalse(context.check_alive(device))

    # kill_task is rejected for a task that is already killed.
    with self.assertRaises(ValueError):
      self._cluster.kill_task(job, index)

    self._cluster.start_task(job, index)

    # Re-apply the current server def before probing again. Without this the
    # next check_alive reports False for the restarted task; sleeping for
    # about 2 seconds instead was also observed to work.
    current_server_def = context.get_server_def()
    context.context().update_server_def(current_server_def)

    self.assertTrue(context.check_alive(device))

  def testStop(self):
    self._cluster.stop()
    for device in self._ALL_TASKS:
      self.assertFalse(context.check_alive(device))

  def testClusterResolverProperty(self):
    spec_dict = self._cluster.cluster_resolver.cluster_spec().as_dict()

    # Expected number of tasks per job, checked in insertion order.
    expected_sizes = {"worker": 2, "ps": 1, "chief": 1}
    for job_name, size in expected_sizes.items():
      self.assertEqual(len(spec_dict[job_name]), size)


if __name__ == "__main__":
  multi_process_runner.test_main()