# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for multi_worker_util."""

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import test
from tensorflow.python.training import server_lib


class NormalizeClusterSpecTest(test.TestCase):
  """Exercises `normalize_cluster_spec` with each kind of input."""

  def assert_same_cluster(self, lhs, rhs):
    """Asserts `lhs` and `rhs` produce equal `ClusterSpec` dictionaries."""
    lhs_dict = server_lib.ClusterSpec(lhs).as_dict()
    rhs_dict = server_lib.ClusterSpec(rhs).as_dict()
    self.assertEqual(lhs_dict, rhs_dict)

  def testDictAsInput(self):
    """A plain dict normalizes to an equivalent cluster."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    normalized = multi_worker_util.normalize_cluster_spec(spec)
    self.assert_same_cluster(spec, normalized)

  def testClusterDefAsInput(self):
    """A `ClusterDef` proto normalizes to an equivalent cluster."""
    cluster_def = cluster_pb2.ClusterDef()
    jobs = (
        ("chief", ("127.0.0.1:1234",)),
        ("worker", ("127.0.0.1:8964", "127.0.0.1:2333")),
        ("ps", ("127.0.0.1:1926", "127.0.0.1:3141")),
    )
    # Populate one job entry per (name, addresses) pair, keyed by task index.
    for job_name, addresses in jobs:
      job_def = cluster_def.job.add()
      job_def.name = job_name
      for task_index, address in enumerate(addresses):
        job_def.tasks[task_index] = address

    normalized = multi_worker_util.normalize_cluster_spec(cluster_def)
    self.assert_same_cluster(cluster_def, normalized)

  def testClusterSpecAsInput(self):
    """A `ClusterSpec` object normalizes to an equivalent cluster."""
    spec = server_lib.ClusterSpec({
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    })
    normalized = multi_worker_util.normalize_cluster_spec(spec)
    self.assert_same_cluster(spec, normalized)

  def testUnexpectedInput(self):
    """A list input is rejected with a `ValueError`."""
    bad_input = ["127.0.0.1:8964", "127.0.0.1:2333"]
    expected_error = (
        "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
        "`tf.train.ClusterDef` object")

    with self.assertRaisesRegex(ValueError, expected_error):
      multi_worker_util.normalize_cluster_spec(bad_input)


class IsChiefTest(test.TestCase):
  """Exercises `is_chief` for clusters with and without a chief job."""

  def testClusterWithChief(self):
    """With a chief job present, only the chief task reports as chief."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    chief_is_chief = multi_worker_util.is_chief(spec, "chief", 0)
    self.assertTrue(chief_is_chief)
    worker_is_chief = multi_worker_util.is_chief(spec, "worker", 0)
    self.assertFalse(worker_is_chief)

  def testClusterWithoutChief(self):
    """Without a chief job, worker 0 is chief; bad lookups raise."""
    spec = {"worker": ["127.0.0.1:8964", "127.0.0.1:2333"]}
    first_worker = multi_worker_util.is_chief(spec, "worker", 0)
    self.assertTrue(first_worker)
    second_worker = multi_worker_util.is_chief(spec, "worker", 1)
    self.assertFalse(second_worker)

    missing_type_error = "`task_type` 'chief' not found in cluster_spec."
    with self.assertRaisesRegex(ValueError, missing_type_error):
      multi_worker_util.is_chief(spec, "chief", 0)

    bad_id_error = "The `task_id` 2 exceeds the maximum id of worker."
    with self.assertRaisesRegex(ValueError, bad_id_error):
      multi_worker_util.is_chief(spec, "worker", 2)

  def testEvaluatorIsChief(self):
    """The evaluator task at index 0 reports as chief."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "evaluator": ["127.0.0.1:2019"],
    }
    evaluator_is_chief = multi_worker_util.is_chief(spec, "evaluator", 0)
    self.assertTrue(evaluator_is_chief)


class NumWorkersTest(test.TestCase):
  """Exercises `worker_count` for the different task types."""

  def testCountWorker(self):
    """One chief plus two workers count as 3 for both task types."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    count_for_chief = multi_worker_util.worker_count(spec, task_type="chief")
    self.assertEqual(count_for_chief, 3)
    count_for_worker = multi_worker_util.worker_count(
        spec, task_type="worker")
    self.assertEqual(count_for_worker, 3)

  def testCountEvaluator(self):
    """The evaluator task type yields a count of 1."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "evaluator": ["127.0.0.1:7566"],
    }
    count_for_evaluator = multi_worker_util.worker_count(
        spec, task_type="evaluator")
    self.assertEqual(count_for_evaluator, 1)

  def testTaskTypeNotFound(self):
    """Counting workers in an empty cluster raises a `ValueError`."""
    empty_spec = {}
    expected_error = "`task_type` 'worker' not found in cluster_spec."
    with self.assertRaisesRegex(ValueError, expected_error):
      multi_worker_util.worker_count(empty_spec, task_type="worker")

  def testCountPs(self):
    """Asking for the count with task type "ps" raises a `ValueError`."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    # A "ps" job shouldn't call this method.
    expected_error = "Unexpected `task_type` 'ps'"
    with self.assertRaisesRegex(ValueError, expected_error):
      multi_worker_util.worker_count(spec, task_type="ps")


class IdInClusterTest(test.TestCase):
  """Exercises `id_in_cluster` for each task type."""

  def testChiefId(self):
    """The chief task gets id 0."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    chief_id = multi_worker_util.id_in_cluster(spec, "chief", 0)
    self.assertEqual(chief_id, 0)

  def testWorkerId(self):
    """Worker 1 gets id 2 when a chief exists and id 1 when it does not."""
    with_chief = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    id_with_chief = multi_worker_util.id_in_cluster(with_chief, "worker", 1)
    self.assertEqual(id_with_chief, 2)

    without_chief = {
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    id_without_chief = multi_worker_util.id_in_cluster(
        without_chief, "worker", 1)
    self.assertEqual(id_without_chief, 1)

  def testEvaluatorId(self):
    """The evaluator task gets id 0."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "evaluator": ["127.0.0.1:7566"],
    }
    evaluator_id = multi_worker_util.id_in_cluster(spec, "evaluator", 0)
    self.assertEqual(evaluator_id, 0)

  def testPsId(self):
    """Requesting an id for a "ps" task raises a `ValueError`."""
    spec = {"chief": ["127.0.0.1:1234"], "ps": ["127.0.0.1:7566"]}
    expected_error = "There is no id for task_type 'ps'"
    with self.assertRaisesRegex(ValueError, expected_error):
      multi_worker_util.id_in_cluster(spec, "ps", 0)

  def testMultipleChiefs(self):
    """A chief job listing two tasks raises a `ValueError`."""
    spec = {
        "chief": ["127.0.0.1:8258", "127.0.0.1:7566"],
    }
    expected_error = "There must be at most one 'chief' job."
    with self.assertRaisesRegex(ValueError, expected_error):
      multi_worker_util.id_in_cluster(spec, "chief", 0)


class CollectiveLeaderTest(test.TestCase):
  """Exercises `collective_leader` for several cluster layouts."""

  def testChiefAsLeader(self):
    """With a chief job present, the leader is the chief task."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    leader = multi_worker_util.collective_leader(spec, "worker", 0)
    self.assertEqual(leader, "/job:chief/replica:0/task:0")

  def testWorkerAsLeader(self):
    """Without a chief job, the leader is worker task 0."""
    spec = {
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    leader = multi_worker_util.collective_leader(spec, "worker", 1)
    self.assertEqual(leader, "/job:worker/replica:0/task:0")

  def testLeaderForEvaluator(self):
    """The leader string for an evaluator task is empty."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
        "evaluator": ["127.0.0.1:2019"],
    }
    leader = multi_worker_util.collective_leader(spec, "evaluator", 0)
    self.assertEqual(leader, "")

  def testLocalLeader(self):
    """An empty cluster with a `None` task type yields an empty leader."""
    empty_spec = {}
    leader = multi_worker_util.collective_leader(empty_spec, None, 0)
    self.assertEqual(leader, "")


# Most of the validation logic is tested by above tests except for some.
class ClusterSpecValidationTest(test.TestCase):
  """Exercises `_validate_cluster_spec` directly."""

  def testEvaluatorNotInCluster(self):
    """Validation does not raise for an evaluator absent from the cluster.

    The test passes as long as none of the calls raises.
    """
    spec = {
        "chief": ["127.0.0.1:1234"],
        "worker": ["127.0.0.1:8964", "127.0.0.1:2333"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    # "evaluator" is deliberately not a key of `spec`.
    for task_type in ("chief", "worker", "ps", "evaluator"):
      multi_worker_util._validate_cluster_spec(spec, task_type, 0)

  def testWorkerNotInCluster(self):
    """A missing "worker" job raises, while a missing evaluator does not."""
    spec = {
        "chief": ["127.0.0.1:1234"],
        "ps": ["127.0.0.1:1926", "127.0.0.1:3141"],
    }
    multi_worker_util._validate_cluster_spec(spec, "evaluator", 0)

    expected_error = "`task_type` 'worker' not found in cluster_spec."
    with self.assertRaisesRegex(ValueError, expected_error):
      multi_worker_util._validate_cluster_spec(spec, "worker", 0)


if __name__ == "__main__":
  test.main()