# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for GCEClusterResolver."""

from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib


mock = test.mock


class GCEClusterResolverTest(test.TestCase):
  """Exercises GCEClusterResolver against mocked GCE service clients."""

  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
    """Asserts `cluster_spec` matches `expected_proto` through every form.

    The spec is compared directly, and again after being rebuilt into a
    `server_lib.ClusterSpec` from the spec itself, from its cluster def, and
    from its dict form.

    Args:
      cluster_spec: The cluster spec produced by the resolver under test.
      expected_proto: Text-format proto the cluster def is compared against.
    """
    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
    self.assertProtoEquals(
        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())

  def standard_mock_instance_groups(self, instance_map=None):
    """Builds a mock for the `instanceGroups()` resource.

    Args:
      instance_map: List of `{'instance': <url>}` dicts returned as the
        `'items'` of the `listInstances` request. Defaults to a single
        `gce-instance-1` entry when None.

    Returns:
      A `mock.Mock` whose `listInstances()` returns a request mock and whose
      `listInstances_next()` returns None.
    """
    if instance_map is None:
      instance_map = [
          {'instance': 'https://gce.example.com/res/gce-instance-1'}
      ]

    mock_instance_group_request = mock.MagicMock()
    mock_instance_group_request.execute.return_value = {
        'items': instance_map
    }

    service_attrs = {
        'listInstances.return_value': mock_instance_group_request,
        # A None follow-up request means there is only one page of results.
        'listInstances_next.return_value': None,
    }
    mock_instance_groups = mock.Mock(**service_attrs)
    return mock_instance_groups

  def standard_mock_instances(self, instance_to_ip_map=None):
    """Builds a mock for the `instances()` resource.

    Args:
      instance_to_ip_map: Dict mapping instance name to the IP reported as
        its first network interface. Defaults to
        `{'gce-instance-1': '10.123.45.67'}` when None.

    Returns:
      A `mock.MagicMock` whose `get(project, zone, instance)` returns a
      request mock for known instances and raises RuntimeError otherwise.
    """
    if instance_to_ip_map is None:
      instance_to_ip_map = {
          'gce-instance-1': '10.123.45.67'
      }

    def get_side_effect(project, zone, instance):
      del project, zone  # Unused

      if instance in instance_to_ip_map:
        mock_get_request = mock.MagicMock()
        mock_get_request.execute.return_value = {
            'networkInterfaces': [
                {'networkIP': instance_to_ip_map[instance]}
            ]
        }
        return mock_get_request
      else:
        raise RuntimeError('Instance %s not found!' % instance)

    service_attrs = {
        'get.side_effect': get_side_effect,
    }
    mock_instances = mock.MagicMock(**service_attrs)
    return mock_instances

  def standard_mock_service_client(
      self,
      mock_instance_groups=None,
      mock_instances=None):
    """Builds a mock service client exposing the two resource mocks.

    Args:
      mock_instance_groups: Mock returned by `instanceGroups()`. Defaults to
        `standard_mock_instance_groups()` when None.
      mock_instances: Mock returned by `instances()`. Defaults to
        `standard_mock_instances()` when None.

    Returns:
      A `mock.MagicMock` standing in for the service client.
    """
    if mock_instance_groups is None:
      mock_instance_groups = self.standard_mock_instance_groups()
    if mock_instances is None:
      mock_instances = self.standard_mock_instances()

    mock_client = mock.MagicMock()
    mock_client.instanceGroups.return_value = mock_instance_groups
    mock_client.instances.return_value = mock_instances
    return mock_client

  def gen_standard_mock_service_client(self, instances=None):
    """Builds a mock service client for an explicit list of instances.

    Args:
      instances: Iterable of `{'name': <instance name>, 'ip': <ip>}` dicts.
        When None, the standard single-instance mock client is returned,
        matching how the other helpers treat a None argument.

    Returns:
      A `mock.MagicMock` standing in for the service client.
    """
    if instances is None:
      # Iterating None would raise TypeError; fall back to the defaults.
      return self.standard_mock_service_client()

    name_to_ip = {}
    instance_list = []
    for instance in instances:
      name_to_ip[instance['name']] = instance['ip']
      instance_list.append({
          'instance': 'https://gce.example.com/gce/res/' + instance['name']
      })

    mock_instance = self.standard_mock_instances(name_to_ip)
    mock_instance_group = self.standard_mock_instance_groups(instance_list)

    return self.standard_mock_service_client(mock_instance_group, mock_instance)

  def testSimpleSuccessfulRetrieval(self):
    # The default single-instance mock should yield a one-task 'worker' job.
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        port=8470,
        credentials=None,
        service=self.standard_mock_service_client())

    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: '10.123.45.67:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testMasterRetrieval(self):
    # With task_id=0, master() is expected to be a grpc:// address.
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_id=0,
        port=8470,
        credentials=None,
        service=self.standard_mock_service_client())
    self.assertEqual(gce_cluster_resolver.master(), 'grpc://10.123.45.67:8470')

  def testMasterRetrievalWithCustomTasks(self):
    # Arguments passed to master() select the task and the rpc layer prefix.
    name_to_ip = [
        {'name': 'instance1', 'ip': '10.1.2.3'},
        {'name': 'instance2', 'ip': '10.2.3.4'},
        {'name': 'instance3', 'ip': '10.3.4.5'},
    ]

    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(name_to_ip))

    self.assertEqual(
        gce_cluster_resolver.master('worker', 2, 'test'),
        'test://10.3.4.5:8470')

  def testOverrideParameters(self):
    # task_id and rpc_layer set after construction should drive master().
    name_to_ip = [
        {'name': 'instance1', 'ip': '10.1.2.3'},
        {'name': 'instance2', 'ip': '10.2.3.4'},
        {'name': 'instance3', 'ip': '10.3.4.5'},
    ]

    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='testworker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(name_to_ip))

    gce_cluster_resolver.task_id = 1
    gce_cluster_resolver.rpc_layer = 'test'

    self.assertEqual(gce_cluster_resolver.task_type, 'testworker')
    self.assertEqual(gce_cluster_resolver.task_id, 1)
    self.assertEqual(gce_cluster_resolver.rpc_layer, 'test')
    self.assertEqual(gce_cluster_resolver.master(), 'test://10.2.3.4:8470')

  def testOverrideParametersWithZeroOrEmpty(self):
    # An explicit task_id=0 passed to master() must win over the
    # constructor's task_id=1, i.e. falsy overrides are still honored.
    name_to_ip = [
        {'name': 'instance1', 'ip': '10.1.2.3'},
        {'name': 'instance2', 'ip': '10.2.3.4'},
        {'name': 'instance3', 'ip': '10.3.4.5'},
    ]

    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='',
        task_id=1,
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(name_to_ip))

    self.assertEqual(gce_cluster_resolver.master(
        task_type='', task_id=0), 'grpc://10.1.2.3:8470')

  def testCustomJobNameAndPortRetrieval(self):
    # task_type and port given to the constructor appear in the cluster spec.
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='custom',
        port=2222,
        credentials=None,
        service=self.standard_mock_service_client())

    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'custom' tasks { key: 0 value: '10.123.45.67:2222' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testMultipleInstancesRetrieval(self):
    # Each instance in the group becomes one task of the 'worker' job.
    name_to_ip = [
        {'name': 'instance1', 'ip': '10.1.2.3'},
        {'name': 'instance2', 'ip': '10.2.3.4'},
        {'name': 'instance3', 'ip': '10.3.4.5'},
    ]

    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(name_to_ip))

    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' }
                         tasks { key: 1 value: '10.2.3.4:8470' }
                         tasks { key: 2 value: '10.3.4.5:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testUnionMultipleInstanceRetrieval(self):
    # Two 'worker' resolvers and one 'ps' resolver are combined; the expected
    # spec has the workers merged into a single job with tasks 0 through 5.
    worker1_name_to_ip = [
        {'name': 'instance1', 'ip': '10.1.2.3'},
        {'name': 'instance2', 'ip': '10.2.3.4'},
        {'name': 'instance3', 'ip': '10.3.4.5'},
    ]

    worker2_name_to_ip = [
        {'name': 'instance4', 'ip': '10.4.5.6'},
        {'name': 'instance5', 'ip': '10.5.6.7'},
        {'name': 'instance6', 'ip': '10.6.7.8'},
    ]

    ps_name_to_ip = [
        {'name': 'ps1', 'ip': '10.100.1.2'},
        {'name': 'ps2', 'ip': '10.100.2.3'},
    ]

    worker1_gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='worker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(worker1_name_to_ip))

    worker2_gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='worker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(worker2_name_to_ip))

    ps_gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='ps',
        port=2222,
        credentials=None,
        service=self.gen_standard_mock_service_client(ps_name_to_ip))

    union_cluster_resolver = UnionClusterResolver(worker1_gce_cluster_resolver,
                                                  worker2_gce_cluster_resolver,
                                                  ps_gce_cluster_resolver)

    actual_cluster_spec = union_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'ps' tasks { key: 0 value: '10.100.1.2:2222' }
                     tasks { key: 1 value: '10.100.2.3:2222' } }
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' }
                         tasks { key: 1 value: '10.2.3.4:8470' }
                         tasks { key: 2 value: '10.3.4.5:8470' }
                         tasks { key: 3 value: '10.4.5.6:8470' }
                         tasks { key: 4 value: '10.5.6.7:8470' }
                         tasks { key: 5 value: '10.6.7.8:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testSettingTaskTypeRaiseError(self):
    # Assigning task_type after construction is expected to raise.
    name_to_ip = [
        {
            'name': 'instance1',
            'ip': '10.1.2.3'
        },
        {
            'name': 'instance2',
            'ip': '10.2.3.4'
        },
        {
            'name': 'instance3',
            'ip': '10.3.4.5'
        },
    ]

    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='testworker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(name_to_ip))

    with self.assertRaisesRegex(
        RuntimeError, 'You cannot reset the task_type '
        'of the GCEClusterResolver after it has '
        'been created.'):
      gce_cluster_resolver.task_type = 'foobar'


if __name__ == '__main__':
  test.main()