1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for GCEClusterResolver."""
16
17from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
18from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver
19from tensorflow.python.platform import test
20from tensorflow.python.training import server_lib
21
22
# Short alias for the `mock` module exposed by `tensorflow.python.platform.test`;
# the helpers below use it to build fake GCE API clients.
mock = test.mock
24
25
26class GCEClusterResolverTest(test.TestCase):
27
28  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
29    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
30    self.assertProtoEquals(
31        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
32    self.assertProtoEquals(
33        expected_proto,
34        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
35    self.assertProtoEquals(
36        expected_proto,
37        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
38
39  def standard_mock_instance_groups(self, instance_map=None):
40    if instance_map is None:
41      instance_map = [
42          {'instance': 'https://gce.example.com/res/gce-instance-1'}
43      ]
44
45    mock_instance_group_request = mock.MagicMock()
46    mock_instance_group_request.execute.return_value = {
47        'items': instance_map
48    }
49
50    service_attrs = {
51        'listInstances.return_value': mock_instance_group_request,
52        'listInstances_next.return_value': None,
53    }
54    mock_instance_groups = mock.Mock(**service_attrs)
55    return mock_instance_groups
56
57  def standard_mock_instances(self, instance_to_ip_map=None):
58    if instance_to_ip_map is None:
59      instance_to_ip_map = {
60          'gce-instance-1': '10.123.45.67'
61      }
62
63    mock_get_request = mock.MagicMock()
64    mock_get_request.execute.return_value = {
65        'networkInterfaces': [
66            {'networkIP': '10.123.45.67'}
67        ]
68    }
69
70    def get_side_effect(project, zone, instance):
71      del project, zone  # Unused
72
73      if instance in instance_to_ip_map:
74        mock_get_request = mock.MagicMock()
75        mock_get_request.execute.return_value = {
76            'networkInterfaces': [
77                {'networkIP': instance_to_ip_map[instance]}
78            ]
79        }
80        return mock_get_request
81      else:
82        raise RuntimeError('Instance %s not found!' % instance)
83
84    service_attrs = {
85        'get.side_effect': get_side_effect,
86    }
87    mock_instances = mock.MagicMock(**service_attrs)
88    return mock_instances
89
90  def standard_mock_service_client(
91      self,
92      mock_instance_groups=None,
93      mock_instances=None):
94
95    if mock_instance_groups is None:
96      mock_instance_groups = self.standard_mock_instance_groups()
97    if mock_instances is None:
98      mock_instances = self.standard_mock_instances()
99
100    mock_client = mock.MagicMock()
101    mock_client.instanceGroups.return_value = mock_instance_groups
102    mock_client.instances.return_value = mock_instances
103    return mock_client
104
105  def gen_standard_mock_service_client(self, instances=None):
106    name_to_ip = {}
107    instance_list = []
108    for instance in instances:
109      name_to_ip[instance['name']] = instance['ip']
110      instance_list.append({
111          'instance': 'https://gce.example.com/gce/res/' + instance['name']
112      })
113
114    mock_instance = self.standard_mock_instances(name_to_ip)
115    mock_instance_group = self.standard_mock_instance_groups(instance_list)
116
117    return self.standard_mock_service_client(mock_instance_group, mock_instance)
118
119  def testSimpleSuccessfulRetrieval(self):
120    gce_cluster_resolver = GCEClusterResolver(
121        project='test-project',
122        zone='us-east1-d',
123        instance_group='test-instance-group',
124        port=8470,
125        credentials=None,
126        service=self.standard_mock_service_client())
127
128    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
129    expected_proto = """
130    job { name: 'worker' tasks { key: 0 value: '10.123.45.67:8470' } }
131    """
132    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
133
134  def testMasterRetrieval(self):
135    gce_cluster_resolver = GCEClusterResolver(
136        project='test-project',
137        zone='us-east1-d',
138        instance_group='test-instance-group',
139        task_id=0,
140        port=8470,
141        credentials=None,
142        service=self.standard_mock_service_client())
143    self.assertEqual(gce_cluster_resolver.master(), 'grpc://10.123.45.67:8470')
144
145  def testMasterRetrievalWithCustomTasks(self):
146    name_to_ip = [
147        {'name': 'instance1', 'ip': '10.1.2.3'},
148        {'name': 'instance2', 'ip': '10.2.3.4'},
149        {'name': 'instance3', 'ip': '10.3.4.5'},
150    ]
151
152    gce_cluster_resolver = GCEClusterResolver(
153        project='test-project',
154        zone='us-east1-d',
155        instance_group='test-instance-group',
156        port=8470,
157        credentials=None,
158        service=self.gen_standard_mock_service_client(name_to_ip))
159
160    self.assertEqual(
161        gce_cluster_resolver.master('worker', 2, 'test'),
162        'test://10.3.4.5:8470')
163
164  def testOverrideParameters(self):
165    name_to_ip = [
166        {'name': 'instance1', 'ip': '10.1.2.3'},
167        {'name': 'instance2', 'ip': '10.2.3.4'},
168        {'name': 'instance3', 'ip': '10.3.4.5'},
169    ]
170
171    gce_cluster_resolver = GCEClusterResolver(
172        project='test-project',
173        zone='us-east1-d',
174        instance_group='test-instance-group',
175        task_type='testworker',
176        port=8470,
177        credentials=None,
178        service=self.gen_standard_mock_service_client(name_to_ip))
179
180    gce_cluster_resolver.task_id = 1
181    gce_cluster_resolver.rpc_layer = 'test'
182
183    self.assertEqual(gce_cluster_resolver.task_type, 'testworker')
184    self.assertEqual(gce_cluster_resolver.task_id, 1)
185    self.assertEqual(gce_cluster_resolver.rpc_layer, 'test')
186    self.assertEqual(gce_cluster_resolver.master(), 'test://10.2.3.4:8470')
187
188  def testOverrideParametersWithZeroOrEmpty(self):
189    name_to_ip = [
190        {'name': 'instance1', 'ip': '10.1.2.3'},
191        {'name': 'instance2', 'ip': '10.2.3.4'},
192        {'name': 'instance3', 'ip': '10.3.4.5'},
193    ]
194
195    gce_cluster_resolver = GCEClusterResolver(
196        project='test-project',
197        zone='us-east1-d',
198        instance_group='test-instance-group',
199        task_type='',
200        task_id=1,
201        port=8470,
202        credentials=None,
203        service=self.gen_standard_mock_service_client(name_to_ip))
204
205    self.assertEqual(gce_cluster_resolver.master(
206        task_type='', task_id=0), 'grpc://10.1.2.3:8470')
207
208  def testCustomJobNameAndPortRetrieval(self):
209    gce_cluster_resolver = GCEClusterResolver(
210        project='test-project',
211        zone='us-east1-d',
212        instance_group='test-instance-group',
213        task_type='custom',
214        port=2222,
215        credentials=None,
216        service=self.standard_mock_service_client())
217
218    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
219    expected_proto = """
220    job { name: 'custom' tasks { key: 0 value: '10.123.45.67:2222' } }
221    """
222    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
223
224  def testMultipleInstancesRetrieval(self):
225    name_to_ip = [
226        {'name': 'instance1', 'ip': '10.1.2.3'},
227        {'name': 'instance2', 'ip': '10.2.3.4'},
228        {'name': 'instance3', 'ip': '10.3.4.5'},
229    ]
230
231    gce_cluster_resolver = GCEClusterResolver(
232        project='test-project',
233        zone='us-east1-d',
234        instance_group='test-instance-group',
235        port=8470,
236        credentials=None,
237        service=self.gen_standard_mock_service_client(name_to_ip))
238
239    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
240    expected_proto = """
241    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' }
242                         tasks { key: 1 value: '10.2.3.4:8470' }
243                         tasks { key: 2 value: '10.3.4.5:8470' } }
244    """
245    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
246
247  def testUnionMultipleInstanceRetrieval(self):
248    worker1_name_to_ip = [
249        {'name': 'instance1', 'ip': '10.1.2.3'},
250        {'name': 'instance2', 'ip': '10.2.3.4'},
251        {'name': 'instance3', 'ip': '10.3.4.5'},
252    ]
253
254    worker2_name_to_ip = [
255        {'name': 'instance4', 'ip': '10.4.5.6'},
256        {'name': 'instance5', 'ip': '10.5.6.7'},
257        {'name': 'instance6', 'ip': '10.6.7.8'},
258    ]
259
260    ps_name_to_ip = [
261        {'name': 'ps1', 'ip': '10.100.1.2'},
262        {'name': 'ps2', 'ip': '10.100.2.3'},
263    ]
264
265    worker1_gce_cluster_resolver = GCEClusterResolver(
266        project='test-project',
267        zone='us-east1-d',
268        instance_group='test-instance-group',
269        task_type='worker',
270        port=8470,
271        credentials=None,
272        service=self.gen_standard_mock_service_client(worker1_name_to_ip))
273
274    worker2_gce_cluster_resolver = GCEClusterResolver(
275        project='test-project',
276        zone='us-east1-d',
277        instance_group='test-instance-group',
278        task_type='worker',
279        port=8470,
280        credentials=None,
281        service=self.gen_standard_mock_service_client(worker2_name_to_ip))
282
283    ps_gce_cluster_resolver = GCEClusterResolver(
284        project='test-project',
285        zone='us-east1-d',
286        instance_group='test-instance-group',
287        task_type='ps',
288        port=2222,
289        credentials=None,
290        service=self.gen_standard_mock_service_client(ps_name_to_ip))
291
292    union_cluster_resolver = UnionClusterResolver(worker1_gce_cluster_resolver,
293                                                  worker2_gce_cluster_resolver,
294                                                  ps_gce_cluster_resolver)
295
296    actual_cluster_spec = union_cluster_resolver.cluster_spec()
297    expected_proto = """
298    job { name: 'ps' tasks { key: 0 value: '10.100.1.2:2222' }
299                     tasks { key: 1 value: '10.100.2.3:2222' } }
300    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' }
301                         tasks { key: 1 value: '10.2.3.4:8470' }
302                         tasks { key: 2 value: '10.3.4.5:8470' }
303                         tasks { key: 3 value: '10.4.5.6:8470' }
304                         tasks { key: 4 value: '10.5.6.7:8470' }
305                         tasks { key: 5 value: '10.6.7.8:8470' } }
306    """
307    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
308
309  def testSettingTaskTypeRaiseError(self):
310    name_to_ip = [
311        {
312            'name': 'instance1',
313            'ip': '10.1.2.3'
314        },
315        {
316            'name': 'instance2',
317            'ip': '10.2.3.4'
318        },
319        {
320            'name': 'instance3',
321            'ip': '10.3.4.5'
322        },
323    ]
324
325    gce_cluster_resolver = GCEClusterResolver(
326        project='test-project',
327        zone='us-east1-d',
328        instance_group='test-instance-group',
329        task_type='testworker',
330        port=8470,
331        credentials=None,
332        service=self.gen_standard_mock_service_client(name_to_ip))
333
334    with self.assertRaisesRegex(
335        RuntimeError, 'You cannot reset the task_type '
336        'of the GCEClusterResolver after it has '
337        'been created.'):
338      gce_cluster_resolver.task_type = 'foobar'
339
340
# Run this module's tests through the TensorFlow test entry point when the
# file is executed directly.
if __name__ == '__main__':
  test.main()
343