# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(sergiitk): to k8s/ package, and get rid of k8s_internal, which is only
#   added to get around circular dependencies caused by k8s.py clashing with
#   k8s/__init__.py
"""Helpers for working with the Kubernetes API: API client management and
namespaced resource operations with built-in retry logic."""
import datetime
import json
import logging
import pathlib
import threading
from typing import Any, Callable, List, Optional, Tuple

from kubernetes import client
from kubernetes import utils
import kubernetes.config
import urllib3.exceptions
import yaml

from framework.helpers import retryers
import framework.helpers.highlighter
from framework.infrastructure.k8s_internal import k8s_log_collector
from framework.infrastructure.k8s_internal import k8s_port_forwarder

logger = logging.getLogger(__name__)

# Type aliases
_HighlighterYaml = framework.helpers.highlighter.HighlighterYaml
PodLogCollector = k8s_log_collector.PodLogCollector
PortForwarder = k8s_port_forwarder.PortForwarder
ApiClient = client.ApiClient
V1Deployment = client.V1Deployment
V1ServiceAccount = client.V1ServiceAccount
V1Pod = client.V1Pod
V1PodList = client.V1PodList
V1Service = client.V1Service
V1Namespace = client.V1Namespace

_timedelta = datetime.timedelta
_ApiException = client.ApiException
_FailToCreateError = utils.FailToCreateError

_RETRY_ON_EXCEPTIONS = (urllib3.exceptions.HTTPError, _ApiException,
                        _FailToCreateError)


def _server_restart_retryer() -> retryers.Retrying:
    return retryers.exponential_retryer_with_timeout(
        retry_on_exceptions=_RETRY_ON_EXCEPTIONS,
        wait_min=_timedelta(seconds=1),
        wait_max=_timedelta(seconds=10),
        timeout=_timedelta(minutes=3))


def _too_many_requests_retryer() -> retryers.Retrying:
    return retryers.exponential_retryer_with_timeout(
        retry_on_exceptions=_RETRY_ON_EXCEPTIONS,
        wait_min=_timedelta(seconds=10),
        wait_max=_timedelta(seconds=30),
        timeout=_timedelta(minutes=3))


def _quick_recovery_retryer() -> retryers.Retrying:
    return retryers.constant_retryer(wait_fixed=_timedelta(seconds=1),
                                     attempts=3,
                                     retry_on_exceptions=_RETRY_ON_EXCEPTIONS)


def label_dict_to_selector(labels: dict) -> str:
    return ','.join(f'{k}=={v}' for k, v in labels.items())
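

# Example (illustrative; the label value is made up for the example):
#   label_dict_to_selector({'app': 'example-server'}) returns the selector
#   string 'app==example-server', which is what list_pods_with_labels() below
#   passes as the `label_selector` argument to the CoreV1Api list call.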


class NotFound(Exception):
    """Indicates the resource is not found on the API server."""


class KubernetesApiManager:
    """Manages a Kubernetes API client and the API objects that share it."""

    _client: ApiClient
    context: str
    apps: client.AppsV1Api
    core: client.CoreV1Api
    _apis: set

    def __init__(self, context: str):
        self.context = context
        self._client = self._new_client_from_context(context)
        self.apps = client.AppsV1Api(self.client)
        self.core = client.CoreV1Api(self.client)
        self._apis = {self.apps, self.core}

    @property
    def client(self) -> ApiClient:
        return self._client

    def close(self):
        self.client.close()

    def reload(self):
        self.close()
        self._client = self._new_client_from_context(self.context)
        # Update default configuration so that modules that initialize
        # ApiClient implicitly (e.g. kubernetes.watch.Watch) get the updates.
        client.Configuration.set_default(self._client.configuration)
        for api in self._apis:
            api.api_client = self._client

    @staticmethod
    def _new_client_from_context(context: str) -> ApiClient:
        client_instance = kubernetes.config.new_client_from_config(
            context=context)
        logger.info('Using kubernetes context "%s", active host: %s', context,
                    client_instance.configuration.host)
        # TODO(sergiitk): fine-tune if we see the total wait unreasonably long.
        client_instance.configuration.retries = 10
        return client_instance


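# Typical usage, as an illustrative sketch (the kubeconfig context and the
# resource names below are assumptions made up for the example):
#
#   api_manager = KubernetesApiManager(context='my-kubeconfig-context')
#   namespace = KubernetesNamespace(api_manager, 'example-namespace')
#   deployment = namespace.get_deployment('example-deployment')
#   namespace.wait_for_deployment_available_replicas('example-deployment')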
class KubernetesNamespace:  # pylint: disable=too-many-public-methods
    """Helpers for working with Kubernetes resources in a single namespace."""

    _highlighter: framework.helpers.highlighter.Highlighter
    _api: KubernetesApiManager
    _name: str

    NEG_STATUS_META = 'cloud.google.com/neg-status'
    DELETE_GRACE_PERIOD_SEC: int = 5
    WAIT_SHORT_TIMEOUT_SEC: int = 60
    WAIT_SHORT_SLEEP_SEC: int = 1
    WAIT_MEDIUM_TIMEOUT_SEC: int = 5 * 60
    WAIT_MEDIUM_SLEEP_SEC: int = 10
    WAIT_LONG_TIMEOUT_SEC: int = 10 * 60
    WAIT_LONG_SLEEP_SEC: int = 30
    WAIT_POD_START_TIMEOUT_SEC: int = 3 * 60

    def __init__(self, api: KubernetesApiManager, name: str):
        self._api = api
        self._name = name
        self._highlighter = _HighlighterYaml()

    @property
    def name(self):
        return self._name

    def _refresh_auth(self):
        logger.info('Reloading k8s api client to refresh the auth.')
        self._api.reload()

    def _apply_manifest(self, manifest):
        return utils.create_from_dict(self._api.client,
                                      manifest,
                                      namespace=self.name)

    def _get_resource(self, method: Callable[[Any], object], *args, **kwargs):
        try:
            return self._execute(method, *args, **kwargs)
        except NotFound:
            # Instead of throwing an error when a resource doesn't exist,
            # just return None.
            return None

    def _execute(self, method: Callable[[Any], object], *args, **kwargs):
        # Note: Intentionally leaving return type as unspecified to not confuse
        # pytype for methods that delegate calls to this wrapper.
        try:
            return method(*args, **kwargs)
        except _RETRY_ON_EXCEPTIONS as err:
            retryer = self._handle_exception(err)
            if retryer is not None:
                return retryer(method, *args, **kwargs)
            raise

    def _handle_exception(self, err: Exception) -> Optional[retryers.Retrying]:
        # TODO(sergiitk): replace returns with match/case when we switch
        #   to py3.10.
        # pylint: disable=too-many-return-statements

        # Unwrap MaxRetryError.
        if isinstance(err, urllib3.exceptions.MaxRetryError):
            return self._handle_exception(err.reason) if err.reason else None

        # We consider all `NewConnectionError`s as caused by a k8s
        # API server restart. `NewConnectionError`s we've seen:
        #   - [Errno 110] Connection timed out
        #   - [Errno 111] Connection refused
        if isinstance(err, urllib3.exceptions.NewConnectionError):
            return _server_restart_retryer()

        # We consider all `ProtocolError`s with a "Connection aborted" message
        # as caused by a k8s API server restart.
        # `ProtocolError`s we've seen:
        #   - RemoteDisconnected('Remote end closed connection
        #     without response')
        #   - ConnectionResetError(104, 'Connection reset by peer')
        if isinstance(err, urllib3.exceptions.ProtocolError):
            if 'connection aborted' in str(err).lower():
                return _server_restart_retryer()
            else:
                # To cover other cases we didn't account for, and haven't
                # seen in the wild, e.g. "Connection broken".
                return _quick_recovery_retryer()

        # ApiException means the server has received our request and responded
        # with an error we can parse (except a few corner cases, e.g. SSLError).
        if isinstance(err, _ApiException):
            return self._handle_api_exception(err)

        # Unwrap FailToCreateError.
        if isinstance(err, _FailToCreateError):
            # We're always sending a single document, so we expect
            # a single wrapped exception in return.
            if len(err.api_exceptions) == 1:
                return self._handle_exception(err.api_exceptions[0])

        return None

    def _handle_api_exception(
            self, err: _ApiException) -> Optional[retryers.Retrying]:
        # TODO(sergiitk): replace returns with match/case when we switch
        #   to py3.10.
        # pylint: disable=too-many-return-statements

        # TODO(sergiitk): can I chain the retryers?
        logger.debug(
            'Handling k8s.ApiException: status=%s reason=%s body=%s headers=%s',
            err.status, err.reason, err.body, err.headers)

        code: int = err.status
        body = err.body.lower() if err.body else ''

        # 401 Unauthorized: token might be expired, attempt auth refresh.
        if code == 401:
            self._refresh_auth()
            return _quick_recovery_retryer()

        # 404 Not Found. Make it easier for the caller to handle 404s.
        if code == 404:
            raise NotFound('Kubernetes API returned 404 Not Found: '
                           f'{self._status_message_or_body(body)}') from err

        # 409 Conflict
        # "Operation cannot be fulfilled on resourcequotas "foo": the object
        # has been modified; please apply your changes to the latest version
        # and try again".
        # See https://github.com/kubernetes/kubernetes/issues/67761
        if code == 409:
            return _quick_recovery_retryer()

        # 429 Too Many Requests: "Too many requests, please try again later"
        if code == 429:
            return _too_many_requests_retryer()

        # 500 Internal Server Error
        if code == 500:
            # Observed when using `kubectl proxy`.
            # "dial tcp 127.0.0.1:8080: connect: connection refused"
            if 'connection refused' in body:
                return _server_restart_retryer()

            # Known 500 errors that should be treated as 429:
            # - Internal Server Error: "/api/v1/namespaces": the server has
            #   received too many requests and has asked us
            #   to try again later
            # - Internal Server Error: "/api/v1/namespaces/foo/services":
            #   the server is currently unable to handle the request
            if ('too many requests' in body or
                    'currently unable to handle the request' in body):
                return _too_many_requests_retryer()

            # In other cases, just retry a few times in case the server
            # resumes normal operation.
            return _quick_recovery_retryer()

        # 504 Gateway Timeout:
        # "Timeout: request did not complete within the allotted timeout"
        if code == 504:
            return _server_restart_retryer()

        return None

    @classmethod
    def _status_message_or_body(cls, body: str) -> str:
        try:
            return str(json.loads(body)['message'])
        except (KeyError, ValueError):
            return body

    def create_single_resource(self, manifest):
        return self._execute(self._apply_manifest, manifest)
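
    # Illustrative call (the manifest below is made up for the example and is
    # not a manifest shipped with the framework):
    #   namespace.create_single_resource({
    #       'apiVersion': 'v1',
    #       'kind': 'ServiceAccount',
    #       'metadata': {'name': 'example-service-account'},
    #   })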

    def get_service(self, name) -> V1Service:
        return self._get_resource(self._api.core.read_namespaced_service, name,
                                  self.name)

    def get_service_account(self, name) -> V1ServiceAccount:
        return self._get_resource(
            self._api.core.read_namespaced_service_account, name, self.name)

    def delete_service(self,
                       name,
                       grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        self._execute(self._api.core.delete_namespaced_service,
                      name=name,
                      namespace=self.name,
                      body=client.V1DeleteOptions(
                          propagation_policy='Foreground',
                          grace_period_seconds=grace_period_seconds))

    def delete_service_account(self,
                               name,
                               grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        self._execute(self._api.core.delete_namespaced_service_account,
                      name=name,
                      namespace=self.name,
                      body=client.V1DeleteOptions(
                          propagation_policy='Foreground',
                          grace_period_seconds=grace_period_seconds))

    def get(self) -> V1Namespace:
        return self._get_resource(self._api.core.read_namespace, self.name)

    def delete(self, grace_period_seconds=DELETE_GRACE_PERIOD_SEC):
        self._execute(self._api.core.delete_namespace,
                      name=self.name,
                      body=client.V1DeleteOptions(
                          propagation_policy='Foreground',
                          grace_period_seconds=grace_period_seconds))

    def wait_for_service_deleted(self,
                                 name: str,
                                 timeout_sec: int = WAIT_SHORT_TIMEOUT_SEC,
                                 wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None:
        retryer = retryers.constant_retryer(
            wait_fixed=_timedelta(seconds=wait_sec),
            timeout=_timedelta(seconds=timeout_sec),
            check_result=lambda service: service is None)
        retryer(self.get_service, name)

    def wait_for_service_account_deleted(
            self,
            name: str,
            timeout_sec: int = WAIT_SHORT_TIMEOUT_SEC,
            wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None:
        retryer = retryers.constant_retryer(
            wait_fixed=_timedelta(seconds=wait_sec),
            timeout=_timedelta(seconds=timeout_sec),
            check_result=lambda service_account: service_account is None)
        retryer(self.get_service_account, name)

    def wait_for_namespace_deleted(self,
                                   timeout_sec: int = WAIT_LONG_TIMEOUT_SEC,
                                   wait_sec: int = WAIT_LONG_SLEEP_SEC) -> None:
        retryer = retryers.constant_retryer(
            wait_fixed=_timedelta(seconds=wait_sec),
            timeout=_timedelta(seconds=timeout_sec),
            check_result=lambda namespace: namespace is None)
        retryer(self.get)

    def wait_for_service_neg(self,
                             name: str,
                             timeout_sec: int = WAIT_SHORT_TIMEOUT_SEC,
                             wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None:
        timeout = _timedelta(seconds=timeout_sec)
        retryer = retryers.constant_retryer(
            wait_fixed=_timedelta(seconds=wait_sec),
            timeout=timeout,
            check_result=self._check_service_neg_annotation)
        try:
            retryer(self.get_service, name)
        except retryers.RetryError as e:
            logger.error(
                'Timeout %s (h:mm:ss) waiting for service %s to report NEG '
                'status. Last service status:\n%s', timeout, name,
                self._pretty_format_status(e.result()))
            raise

    def get_service_neg(self, service_name: str,
                        service_port: int) -> Tuple[str, List[str]]:
        service = self.get_service(service_name)
        neg_info: dict = json.loads(
            service.metadata.annotations[self.NEG_STATUS_META])
        neg_name: str = neg_info['network_endpoint_groups'][str(service_port)]
        neg_zones: List[str] = neg_info['zones']
        return neg_name, neg_zones
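
    # For reference, the NEG status annotation parsed above has roughly this
    # shape (the values here are made up for illustration):
    #   {"network_endpoint_groups": {"8080": "example-neg-name"},
    #    "zones": ["us-central1-a", "us-central1-b"]}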

    def get_deployment(self, name) -> V1Deployment:
        return self._get_resource(self._api.apps.read_namespaced_deployment,
                                  name, self.name)

    def delete_deployment(
            self,
            name: str,
            grace_period_seconds: int = DELETE_GRACE_PERIOD_SEC) -> None:
        self._execute(self._api.apps.delete_namespaced_deployment,
                      name=name,
                      namespace=self.name,
                      body=client.V1DeleteOptions(
                          propagation_policy='Foreground',
                          grace_period_seconds=grace_period_seconds))

    def list_deployment_pods(self, deployment: V1Deployment) -> List[V1Pod]:
        # V1LabelSelector.match_expressions not supported at the moment
        return self.list_pods_with_labels(deployment.spec.selector.match_labels)

    def wait_for_deployment_available_replicas(
            self,
            name: str,
            count: int = 1,
            timeout_sec: int = WAIT_MEDIUM_TIMEOUT_SEC,
            wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None:
        timeout = _timedelta(seconds=timeout_sec)
        retryer = retryers.constant_retryer(
            wait_fixed=_timedelta(seconds=wait_sec),
            timeout=timeout,
            check_result=lambda depl: self._replicas_available(depl, count))
        try:
            retryer(self.get_deployment, name)
        except retryers.RetryError as e:
            logger.error(
                'Timeout %s (h:mm:ss) waiting for deployment %s to report %i '
                'replicas available. Last status:\n%s', timeout, name, count,
                self._pretty_format_status(e.result()))
            raise

    def wait_for_deployment_replica_count(
            self,
            deployment: V1Deployment,
            count: int = 1,
            *,
            timeout_sec: int = WAIT_MEDIUM_TIMEOUT_SEC,
            wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None:
        timeout = _timedelta(seconds=timeout_sec)
        retryer = retryers.constant_retryer(
            wait_fixed=_timedelta(seconds=wait_sec),
            timeout=timeout,
            check_result=lambda pods: len(pods) == count)
        try:
            retryer(self.list_deployment_pods, deployment)
        except retryers.RetryError as e:
            result = e.result(default=[])
            logger.error(
                'Timeout %s (h:mm:ss) waiting for pod count %i, got: %i. '
                'Pod statuses:\n%s', timeout, count, len(result),
                self._pretty_format_statuses(result))
            raise


    def wait_for_deployment_deleted(
            self,
            deployment_name: str,
            timeout_sec: int = WAIT_MEDIUM_TIMEOUT_SEC,
            wait_sec: int = WAIT_MEDIUM_SLEEP_SEC) -> None:
        retryer = retryers.constant_retryer(
            wait_fixed=_timedelta(seconds=wait_sec),
            timeout=_timedelta(seconds=timeout_sec),
            check_result=lambda deployment: deployment is None)
        retryer(self.get_deployment, deployment_name)

    def list_pods_with_labels(self, labels: dict) -> List[V1Pod]:
        pod_list: V1PodList = self._execute(
            self._api.core.list_namespaced_pod,
            self.name,
            label_selector=label_dict_to_selector(labels))
        return pod_list.items

    def get_pod(self, name: str) -> V1Pod:
        return self._get_resource(self._api.core.read_namespaced_pod, name,
                                  self.name)

    def wait_for_pod_started(self,
                             pod_name: str,
                             timeout_sec: int = WAIT_POD_START_TIMEOUT_SEC,
                             wait_sec: int = WAIT_SHORT_SLEEP_SEC) -> None:
        timeout = _timedelta(seconds=timeout_sec)
        retryer = retryers.constant_retryer(
            wait_fixed=_timedelta(seconds=wait_sec),
            timeout=timeout,
            check_result=self._pod_started)
        try:
            retryer(self.get_pod, pod_name)
        except retryers.RetryError as e:
            logger.error(
                'Timeout %s (h:mm:ss) waiting for pod %s to start. '
                'Pod status:\n%s', timeout, pod_name,
                self._pretty_format_status(e.result()))
            raise

    def port_forward_pod(
        self,
        pod: V1Pod,
        remote_port: int,
        local_port: Optional[int] = None,
        local_address: Optional[str] = None,
    ) -> k8s_port_forwarder.PortForwarder:
        pf = k8s_port_forwarder.PortForwarder(self._api.context, self.name,
                                              f"pod/{pod.metadata.name}",
                                              remote_port, local_port,
                                              local_address)
        pf.connect()
        return pf
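
    # Illustrative usage (the pod name and port are assumptions made up for
    # the example):
    #   pod = namespace.get_pod('example-pod')
    #   forwarder = namespace.port_forward_pod(pod, remote_port=8080)
    #   # ... interact with the forwarded local port, then dispose of the
    #   # forwarder when it is no longer needed.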

    def pod_start_logging(self,
                          *,
                          pod_name: str,
                          log_path: pathlib.Path,
                          log_stop_event: threading.Event,
                          log_to_stdout: bool = False,
                          log_timestamps: bool = False) -> PodLogCollector:
        pod_log_collector = PodLogCollector(
            pod_name=pod_name,
            namespace_name=self.name,
            read_pod_log_fn=self._api.core.read_namespaced_pod_log,
            stop_event=log_stop_event,
            log_path=log_path,
            log_to_stdout=log_to_stdout,
            log_timestamps=log_timestamps)
        pod_log_collector.start()
        return pod_log_collector
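
    # Illustrative usage (the pod name and log path are assumptions made up
    # for the example; the stop event is presumably what signals the collector
    # to finish):
    #   stop_event = threading.Event()
    #   collector = namespace.pod_start_logging(
    #       pod_name='example-pod',
    #       log_path=pathlib.Path('/tmp/example-pod.log'),
    #       log_stop_event=stop_event)
    #   # ... later, when the logs are no longer needed:
    #   stop_event.set()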

    def _pretty_format_statuses(self,
                                k8s_objects: List[Optional[object]]) -> str:
        return '\n'.join(
            self._pretty_format_status(k8s_object)
            for k8s_object in k8s_objects)

    def _pretty_format_status(self, k8s_object: Optional[object]) -> str:
        if k8s_object is None:
            return 'No data'

        # Parse the name if present.
        if hasattr(k8s_object, 'metadata') and hasattr(k8s_object.metadata,
                                                       'name'):
            name = k8s_object.metadata.name
        else:
            name = 'Can\'t parse resource name'

        # Pretty-print the status if present.
        if hasattr(k8s_object, 'status'):
            try:
                status = self._pretty_format(k8s_object.status.to_dict())
            except Exception as e:  # pylint: disable=broad-except
                # Catching all exceptions because not printing the status
                # isn't as important as the system under test.
                status = f'Can\'t parse resource status: {e}'
        else:
            status = 'Can\'t parse resource status'

        # Return the name of the k8s object and its pretty-printed status.
        return f'{name}:\n{status}\n'

    def _pretty_format(self, data: dict) -> str:
        """Return a string with pretty-printed yaml data from a python dict."""
        yaml_out: str = yaml.dump(data, explicit_start=True, explicit_end=True)
        return self._highlighter.highlight(yaml_out)

    @classmethod
    def _check_service_neg_annotation(cls,
                                      service: Optional[V1Service]) -> bool:
        return (isinstance(service, V1Service) and
                cls.NEG_STATUS_META in service.metadata.annotations)

    @classmethod
    def _pod_started(cls, pod: V1Pod) -> bool:
        return (isinstance(pod, V1Pod) and
                pod.status.phase not in ('Pending', 'Unknown'))

    @classmethod
    def _replicas_available(cls, deployment: V1Deployment, count: int) -> bool:
        return (isinstance(deployment, V1Deployment) and
                deployment.status.available_replicas is not None and
                deployment.status.available_replicas >= count)