1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of Cluster Resolvers for Kubernetes."""
16
17from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
18from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url
19from tensorflow.python.training import server_lib
20from tensorflow.python.util.tf_export import tf_export
21
22
@tf_export('distribute.cluster_resolver.KubernetesClusterResolver')
class KubernetesClusterResolver(ClusterResolver):
  """ClusterResolver for Kubernetes.

  This is an implementation of cluster resolvers for Kubernetes. When given a
  mapping of TensorFlow job names to Kubernetes label selectors, we list the
  pods matching each selector (across all namespaces) and return a ClusterSpec
  built from the host IP address reported in each pod's status together with
  the configured TensorFlow server port. Every matching pod must be in the
  `Running` phase.

  Note: it cannot retrieve `task_type`, `task_id` or `rpc_layer`. To use it
  with some distribution strategies like
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`, you will need to
  specify `task_type` and `task_id` by setting these attributes.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = KubernetesClusterResolver(
        {"worker": ["job-name=worker-cluster-a", "job-name=worker-cluster-b"]})
    cluster_resolver.task_type = "worker"
    cluster_resolver.task_id = 0
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = KubernetesClusterResolver(
        {"worker": ["job-name=worker-cluster-a", "job-name=worker-cluster-b"]})
    cluster_resolver.task_type = "worker"
    cluster_resolver.task_id = 1
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               job_to_label_mapping=None,
               tf_server_port=8470,
               rpc_layer='grpc',
               override_client=None):
    """Initializes a new KubernetesClusterResolver.

    This initializes a new Kubernetes ClusterResolver. The ClusterResolver
    will attempt to talk to the Kubernetes master to retrieve all the instances
    of pods matching a label selector.

    Args:
      job_to_label_mapping: A mapping of TensorFlow jobs to label selectors.
        This allows users to specify many TensorFlow jobs in one Cluster
        Resolver, and each job can have pods belonging to different label
        selectors. For example, a sample mapping might be
        ```
        {'worker': ['job-name=worker-cluster-a', 'job-name=worker-cluster-b'],
         'ps': ['job-name=ps-1', 'job-name=ps-2']}
        ```
        Defaults to `{'worker': ['job-name=tensorflow']}` when omitted or
        empty.
      tf_server_port: The port the TensorFlow server is listening on.
      rpc_layer: (Optional) The RPC layer TensorFlow should use to communicate
        between tasks in Kubernetes. Defaults to 'grpc'.
      override_client: The Kubernetes client (usually automatically retrieved
        using `from kubernetes import client as k8sclient`). If you pass this
        in, you are responsible for setting Kubernetes credentials manually;
        the resolver does not load any Kubernetes configuration itself.

    Raises:
      ImportError: If the Kubernetes Python client is not installed and no
        `override_client` is passed in.
    """
    # A caller-supplied client carries its own credentials, so the local
    # kubeconfig is only loaded when the resolver has to build the client
    # itself. Loading it unconditionally would fail for a pre-configured
    # client on a machine that has no kubeconfig file.
    if not override_client:
      try:
        from kubernetes import config as k8sconfig  # pylint: disable=g-import-not-at-top
      except ImportError:
        raise ImportError('The Kubernetes Python client must be installed '
                          'before using the Kubernetes Cluster Resolver. '
                          'To install the Kubernetes Python client, run '
                          '`pip install kubernetes` on your command line.')

      # Fail fast at construction time if the configuration cannot be loaded.
      k8sconfig.load_kube_config()

    if not job_to_label_mapping:
      job_to_label_mapping = {'worker': ['job-name=tensorflow']}

    self._job_to_label_mapping = job_to_label_mapping
    self._tf_server_port = tf_server_port
    self._override_client = override_client

    self.task_type = None
    self.task_id = None
    self.rpc_layer = rpc_layer

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a session.

    You must have set the task_type and task_id object properties before
    calling this function, or pass in the `task_type` and `task_id`
    parameters when using this function. If you do both, the function parameters
    will override the object properties.

    Note: this is only useful for TensorFlow 1.x.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC protocol for the given cluster. Falls back
        to the resolver's `rpc_layer` attribute when not given.

    Returns:
      The name or URL of the session master, or an empty string when the task
      type or task id is not known.
    """
    # Explicit arguments take precedence over the attributes set on the object.
    task_type = task_type if task_type is not None else self.task_type
    task_id = task_id if task_id is not None else self.task_id

    if task_type is None or task_id is None:
      return ''

    return format_master_url(
        self.cluster_spec().task_address(task_type, task_id),
        rpc_layer or self.rpc_layer)

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest info from Kubernetes.

    We retrieve the information from the Kubernetes master every time this
    method is called.

    Returns:
      A ClusterSpec containing host information returned from Kubernetes.

    Raises:
      RuntimeError: If any of the pods returned by the master is not in the
        `Running` phase.
    """
    if self._override_client:
      client = self._override_client
    else:
      from kubernetes import config as k8sconfig  # pylint: disable=g-import-not-at-top
      from kubernetes import client as k8sclient  # pylint: disable=g-import-not-at-top

      k8sconfig.load_kube_config()
      client = k8sclient.CoreV1Api()

    cluster_map = {}

    for tf_job, selectors in self._job_to_label_mapping.items():
      addresses = []
      for selector in selectors:
        ret = client.list_pod_for_all_namespaces(label_selector=selector)

        # Sort the list by the name to make sure it doesn't change call to call.
        for pod in sorted(ret.items, key=lambda p: p.metadata.name):
          if pod.status.phase != 'Running':
            raise RuntimeError('Pod "%s" is not running; phase: "%s"' %
                               (pod.metadata.name, pod.status.phase))
          # NOTE(review): this uses the host (node) IP rather than the pod IP;
          # presumably pods are expected to be reachable on the node address.
          # Confirm against the deployment's networking setup.
          addresses.append(
              '%s:%s' % (pod.status.host_ip, self._tf_server_port))
      cluster_map[tf_job] = addresses

    return server_lib.ClusterSpec(cluster_map)
182