1# Copyright 2022 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""
15Common functionality for running xDS Test Client and Server on Kubernetes.
16"""
17from abc import ABCMeta
18import contextlib
19import dataclasses
20import datetime
21import logging
22import pathlib
23from typing import List, Optional
24
25import mako.template
26import yaml
27
28from framework.helpers import retryers
29import framework.helpers.datetime
30import framework.helpers.highlighter
31import framework.helpers.rand
32from framework.infrastructure import gcp
33from framework.infrastructure import k8s
34from framework.test_app.runners import base_runner
35
logger = logging.getLogger(__name__)

# Type aliases
# Error type raised by runners on unexpected manifest/resource results.
_RunnerError = base_runner.RunnerError
# Syntax highlighter used when logging rendered YAML manifests.
_HighlighterYaml = framework.helpers.highlighter.HighlighterYaml
# Project datetime helpers (e.g. iso8601_utc_time); aliased so they don't
# clash with the stdlib `datetime` module imported above.
_helper_datetime = framework.helpers.datetime
_datetime = datetime.datetime
_timedelta = datetime.timedelta
44
45
@dataclasses.dataclass(frozen=True)
class RunHistory:
    """Immutable record of one stopped run of a Kubernetes runner.

    Records are appended to KubernetesBaseRunner.run_history when a run that
    has a deployment id is stopped. They are later used to build a separate
    GCP Logs Explorer link per run.
    """
    # Random id of the deployment created for this run; matched against the
    # "k8s-pod/deployment_id" label in Logs Explorer queries.
    deployment_id: str
    # When the start of the run was requested.
    time_start_requested: _datetime
    # When the start completed; None if the run was stopped before the
    # start completed.
    time_start_completed: Optional[_datetime]
    # When the run was stopped.
    time_stopped: _datetime
52
53
class KubernetesBaseRunner(base_runner.BaseRunner, metaclass=ABCMeta):
    """Common base for runners that deploy xDS test apps to Kubernetes.

    Manages the Kubernetes resources of a run (namespace, service account,
    deployment, service), the port forwarders and pod log collectors attached
    to the deployment's pods, and the history of stopped runs, which is used
    to print GCP Logs Explorer links.
    """
    # Pylint wants abstract classes to override abstract methods.
    # pylint: disable=abstract-method

    TEMPLATE_DIR_NAME = 'kubernetes-manifests'
    # Resolved relative to the directory containing this module.
    TEMPLATE_DIR_RELATIVE_PATH = f'../../../../{TEMPLATE_DIR_NAME}'
    ROLE_WORKLOAD_IDENTITY_USER = 'roles/iam.workloadIdentityUser'
    pod_port_forwarders: List[k8s.PortForwarder]
    pod_log_collectors: List[k8s.PodLogCollector]

    # Required fields.
    k8s_namespace: k8s.KubernetesNamespace
    deployment_name: str
    image_name: str
    gcp_project: str
    gcp_service_account: str
    gcp_ui_url: str

    # Fields with default values.
    namespace_template: str = 'namespace.yaml'
    reuse_namespace: bool = False

    # Mutable state. Describes the current run.
    namespace: Optional[k8s.V1Namespace] = None
    deployment: Optional[k8s.V1Deployment] = None
    deployment_id: Optional[str] = None
    service_account: Optional[k8s.V1ServiceAccount] = None
    time_start_requested: Optional[_datetime] = None
    time_start_completed: Optional[_datetime] = None
    time_stopped: Optional[_datetime] = None
    # The history of all runs performed by this runner.
    run_history: List[RunHistory]

    def __init__(self,
                 k8s_namespace: k8s.KubernetesNamespace,
                 *,
                 deployment_name: str,
                 image_name: str,
                 gcp_project: str,
                 gcp_service_account: str,
                 gcp_ui_url: str,
                 namespace_template: Optional[str] = 'namespace.yaml',
                 reuse_namespace: bool = False):
        """Initializes the runner without creating any resources.

        Args:
            k8s_namespace: Manager of the Kubernetes namespace the app runs in.
            deployment_name: Name of the Kubernetes deployment. It is also
                used as the container name in Logs Explorer queries.
            image_name: Test app image name.
            gcp_project: GCP project id.
            gcp_service_account: GCP service account that the Kubernetes
                service account is mapped to.
            gcp_ui_url: GCP UI URL used to build Logs Explorer links.
            namespace_template: Manifest template used to create the
                namespace. A falsy value keeps the class default.
            reuse_namespace: When True, an existing namespace is reused, and
                _cleanup_namespace() won't delete it unless forced.
        """
        super().__init__()

        # Required fields.
        self.deployment_name = deployment_name
        self.image_name = image_name
        self.gcp_project = gcp_project
        # Maps GCP service account to Kubernetes service account
        self.gcp_service_account = gcp_service_account
        self.gcp_ui_url = gcp_ui_url

        # Kubernetes namespace resources manager.
        self.k8s_namespace = k8s_namespace
        if namespace_template:
            self.namespace_template = namespace_template
        self.reuse_namespace = reuse_namespace

        # Mutable state
        self.run_history = []
        self.pod_port_forwarders = []
        self.pod_log_collectors = []

        # Highlighter.
        self._highlighter = _HighlighterYaml()

    def run(self, **kwargs):
        """Starts a new run: resets state and creates or reuses the namespace.

        Also logs a Logs Explorer link to all runs of the deployment.
        Keyword arguments are ignored here.

        Raises:
            RuntimeError: A previous run was started and hasn't been stopped.
        """
        del kwargs
        if not self.time_stopped and self.time_start_requested:
            if self.time_start_completed:
                raise RuntimeError(
                    f"Deployment {self.deployment_name}: has already been"
                    f" started at {self.time_start_completed.isoformat()}")
            else:
                raise RuntimeError(
                    f"Deployment {self.deployment_name}: start has already been"
                    f" requested at {self.time_start_requested.isoformat()}")

        self._reset_state()
        self.time_start_requested = _datetime.now()

        self.logs_explorer_link()
        if self.reuse_namespace:
            self.namespace = self._reuse_namespace()
        # Create the namespace when not reusing, or when reuse found nothing.
        if not self.namespace:
            self.namespace = self._create_namespace(
                self.namespace_template, namespace_name=self.k8s_namespace.name)

    def _start_completed(self):
        """Records the time the start of the current run completed."""
        self.time_start_completed = _datetime.now()

    def _stop(self):
        """Marks the current run stopped and records it in run_history.

        A run is only recorded when its start was requested and it has a
        deployment id.
        """
        self.time_stopped = _datetime.now()
        if self.time_start_requested and self.deployment_id:
            run_history = RunHistory(
                deployment_id=self.deployment_id,
                time_start_requested=self.time_start_requested,
                time_start_completed=self.time_start_completed,
                time_stopped=self.time_stopped,
            )
            self.run_history.append(run_history)

    def _reset_state(self):
        """Reset the mutable state of the previous run.

        Warns about port forwarders and log collectors that weren't cleaned
        up. run_history is preserved.
        """
        if self.pod_port_forwarders:
            logger.warning(
                "Port forwarders weren't cleaned up from the past run: %s",
                len(self.pod_port_forwarders))

        if self.pod_log_collectors:
            logger.warning(
                "Pod log collectors weren't cleaned up from the past run: %s",
                len(self.pod_log_collectors))

        self.namespace = None
        self.deployment = None
        self.deployment_id = None
        self.service_account = None
        self.time_start_requested = None
        self.time_start_completed = None
        self.time_stopped = None
        self.pod_port_forwarders = []
        self.pod_log_collectors = []

    def _cleanup_namespace(self, *, force=False):
        """Deletes the namespace, unless it's reused and force isn't set."""
        if (self.namespace and not self.reuse_namespace) or force:
            self.delete_namespace()
            self.namespace = None

    def stop_pod_dependencies(self, *, log_drain_sec: int = 0):
        """Stops the port forwarders and log collectors of the current run.

        Args:
            log_drain_sec: When positive, wait up to this many seconds per
                log collector for its drain event before flushing it.
        """
        # Signal to stop logging early so less drain time needed.
        self.maybe_stop_logging()

        # Stop port forwarders if any.
        for pod_port_forwarder in self.pod_port_forwarders:
            pod_port_forwarder.close()
        self.pod_port_forwarders = []

        for pod_log_collector in self.pod_log_collectors:
            if log_drain_sec > 0 and not pod_log_collector.drain_event.is_set():
                logger.info("Draining logs for %s, timeout %i sec",
                            pod_log_collector.pod_name, log_drain_sec)
                # The close will happen normally at the next message.
                pod_log_collector.drain_event.wait(timeout=log_drain_sec)
            # Note this will be called from the main thread and may cause
            # a race for the log file. Still, at least it'll flush the buffers.
            pod_log_collector.flush()

        self.pod_log_collectors = []

    def get_pod_restarts(self, deployment: k8s.V1Deployment) -> int:
        """Returns the total container restart count of the deployment's pods.

        Returns 0 when there's no namespace manager or no deployment. Pods
        that don't report container statuses contribute zero restarts.
        """
        if not self.k8s_namespace or not deployment:
            return 0
        total_restart: int = 0
        pods: List[k8s.V1Pod] = self.k8s_namespace.list_deployment_pods(
            deployment)
        for pod in pods:
            # The Kubernetes client models status.container_statuses as
            # optional: it's None until the container statuses are reported.
            statuses = (pod.status and pod.status.container_statuses) or []
            total_restart += sum(status.restart_count for status in statuses)
        return total_restart

    @classmethod
    def _render_template(cls, template_file, **kwargs):
        """Renders a Mako template file, using kwargs as template variables."""
        template = mako.template.Template(filename=str(template_file))
        return template.render(**kwargs)

    @classmethod
    def _manifests_from_yaml_file(cls, yaml_file):
        """Yields every YAML document parsed from the given file."""
        with open(yaml_file, encoding='utf-8') as f:
            with contextlib.closing(yaml.safe_load_all(f)) as yml:
                for manifest in yml:
                    yield manifest

    @classmethod
    def _manifests_from_str(cls, document):
        """Yields every YAML document parsed from the given string."""
        with contextlib.closing(yaml.safe_load_all(document)) as yml:
            for manifest in yml:
                yield manifest

    @classmethod
    def _template_file_from_name(cls, template_name):
        """Returns the resolved path to a template in the templates dir."""
        templates_path = (pathlib.Path(__file__).parent /
                          cls.TEMPLATE_DIR_RELATIVE_PATH)
        return templates_path.joinpath(template_name).resolve()

    def _create_from_template(self, template_name, **kwargs) -> object:
        """Renders a single-document template and creates its k8s resource.

        Args:
            template_name: Template file name within TEMPLATE_DIR_NAME.
            **kwargs: Template variables.

        Returns:
            The single Kubernetes object created from the manifest.

        Raises:
            _RunnerError: The rendered template contains no document, more
                than one document, or didn't create exactly one object.
        """
        template_file = self._template_file_from_name(template_name)
        logger.debug("Loading k8s manifest template: %s", template_file)

        yaml_doc = self._render_template(template_file, **kwargs)
        logger.info("Rendered template %s/%s:\n%s", self.TEMPLATE_DIR_NAME,
                    template_name, self._highlighter.highlight(yaml_doc))

        manifests = self._manifests_from_str(yaml_doc)
        # An empty (or comment-only) template yields no documents; report it
        # as a runner error rather than letting StopIteration escape.
        manifest = next(manifests, None)
        if manifest is None:
            raise _RunnerError('Exactly one document expected in manifest '
                               f'{template_file}, found none')
        # Error out on multi-document yaml
        if next(manifests, False):
            raise _RunnerError('Exactly one document expected in manifest '
                               f'{template_file}')

        k8s_objects = self.k8s_namespace.create_single_resource(manifest)
        if len(k8s_objects) != 1:
            raise _RunnerError('Expected exactly one object to be created '
                               f'from manifest {template_file}')

        logger.info('%s %s created', k8s_objects[0].kind,
                    k8s_objects[0].metadata.name)
        return k8s_objects[0]

    def _reuse_deployment(self, deployment_name) -> k8s.V1Deployment:
        """Returns the existing deployment with the given name."""
        deployment = self.k8s_namespace.get_deployment(deployment_name)
        # TODO(sergiitk): check if good or must be recreated
        return deployment

    def _reuse_service(self, service_name) -> k8s.V1Service:
        """Returns the existing service with the given name."""
        service = self.k8s_namespace.get_service(service_name)
        # TODO(sergiitk): check if good or must be recreated
        return service

    def _reuse_namespace(self) -> k8s.V1Namespace:
        """Returns the existing namespace.

        run() creates the namespace when this returns a falsy value.
        Presumably that's None for a missing namespace; confirm in
        KubernetesNamespace.get().
        """
        return self.k8s_namespace.get()

    def _create_namespace(self, template, **kwargs) -> k8s.V1Namespace:
        """Creates the namespace from a template.

        Raises:
            _RunnerError: The manifest didn't create a V1Namespace named
                kwargs['namespace_name'].
        """
        namespace = self._create_from_template(template, **kwargs)
        if not isinstance(namespace, k8s.V1Namespace):
            raise _RunnerError('Expected V1Namespace to be created '
                               f'from manifest {template}')
        if namespace.metadata.name != kwargs['namespace_name']:
            raise _RunnerError('V1Namespace created with unexpected name: '
                               f'{namespace.metadata.name}')
        logger.debug('V1Namespace %s created at %s',
                     namespace.metadata.self_link,
                     namespace.metadata.creation_timestamp)
        return namespace

    @classmethod
    def _get_workload_identity_member_name(cls, project, namespace_name,
                                           service_account_name):
        """
        Returns workload identity member name used to authenticate Kubernetes
        service accounts.

        https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity
        """
        return (f'serviceAccount:{project}.svc.id.goog'
                f'[{namespace_name}/{service_account_name}]')

    def _grant_workload_identity_user(self, *, gcp_iam, gcp_service_account,
                                      service_account_name):
        """Grants the workload identity user role to the k8s service account.

        Adds an IAM policy binding on gcp_service_account for the member
        derived from this namespace and service_account_name.
        """
        workload_identity_member = self._get_workload_identity_member_name(
            gcp_iam.project, self.k8s_namespace.name, service_account_name)
        logger.info('Granting %s to %s for GCP Service Account %s',
                    self.ROLE_WORKLOAD_IDENTITY_USER, workload_identity_member,
                    gcp_service_account)

        gcp_iam.add_service_account_iam_policy_binding(
            gcp_service_account, self.ROLE_WORKLOAD_IDENTITY_USER,
            workload_identity_member)

    def _revoke_workload_identity_user(self, *, gcp_iam, gcp_service_account,
                                       service_account_name):
        """Revokes the workload identity user role from the k8s service account.

        Best-effort: GCP API errors are logged as warnings, not raised.
        """
        workload_identity_member = self._get_workload_identity_member_name(
            gcp_iam.project, self.k8s_namespace.name, service_account_name)
        logger.info('Revoking %s from %s for GCP Service Account %s',
                    self.ROLE_WORKLOAD_IDENTITY_USER, workload_identity_member,
                    gcp_service_account)
        try:
            gcp_iam.remove_service_account_iam_policy_binding(
                gcp_service_account, self.ROLE_WORKLOAD_IDENTITY_USER,
                workload_identity_member)
        except gcp.api.Error as error:
            logger.warning(
                'Failed to revoke %s from %s for GCP Service Account %s: %r',
                self.ROLE_WORKLOAD_IDENTITY_USER, workload_identity_member,
                gcp_service_account, error)

    def _create_service_account(self, template,
                                **kwargs) -> k8s.V1ServiceAccount:
        """Creates the k8s service account from a template.

        Raises:
            _RunnerError: The manifest didn't create a V1ServiceAccount named
                kwargs['service_account_name'].
        """
        resource = self._create_from_template(template, **kwargs)
        if not isinstance(resource, k8s.V1ServiceAccount):
            raise _RunnerError('Expected V1ServiceAccount to be created '
                               f'from manifest {template}')
        if resource.metadata.name != kwargs['service_account_name']:
            raise _RunnerError('V1ServiceAccount created with unexpected name: '
                               f'{resource.metadata.name}')
        logger.debug('V1ServiceAccount %s created at %s',
                     resource.metadata.self_link,
                     resource.metadata.creation_timestamp)
        return resource

    def _create_deployment(self, template, **kwargs) -> k8s.V1Deployment:
        """Creates the deployment from a template and sets deployment_id.

        Raises:
            TypeError: kwargs has no 'deployment_name'.
            _RunnerError: The manifest didn't create a V1Deployment named
                kwargs['deployment_name'].
        """
        # Not making deployment_name an explicit kwarg to be consistent with
        # the rest of the _create_* methods, which pass kwargs as-is
        # to _create_from_template(), so that the kwargs dict is unpacked into
        # template variables and their values.
        if 'deployment_name' not in kwargs:
            raise TypeError('Missing required keyword-only argument: '
                            'deployment_name')

        # Automatically apply random deployment_id to use in the matchLabels
        # to prevent selecting pods in the same namespace belonging to
        # a different deployment.
        if 'deployment_id' not in kwargs:
            rand_id: str = framework.helpers.rand.rand_string(lowercase=True)
            # Fun edge case: when rand_string() happen to generate numbers only,
            # yaml interprets deployment_id label value as an integer,
            # but k8s expects label values to be strings. Lol. K8s responds
            # with a barely readable 400 Bad Request error: 'ReadString: expects
            # \" or n, but found 9, error found in #10 byte of ...|ent_id'.
            # Prepending deployment name forces deployment_id into a string,
            # as well as it's just a better description.
            self.deployment_id = f'{kwargs["deployment_name"]}-{rand_id}'
            kwargs['deployment_id'] = self.deployment_id
        else:
            self.deployment_id = kwargs['deployment_id']

        deployment = self._create_from_template(template, **kwargs)
        if not isinstance(deployment, k8s.V1Deployment):
            raise _RunnerError('Expected V1Deployment to be created '
                               f'from manifest {template}')
        if deployment.metadata.name != kwargs['deployment_name']:
            raise _RunnerError('V1Deployment created with unexpected name: '
                               f'{deployment.metadata.name}')
        logger.debug('V1Deployment %s created at %s',
                     deployment.metadata.self_link,
                     deployment.metadata.creation_timestamp)
        return deployment

    def _create_service(self, template, **kwargs) -> k8s.V1Service:
        """Creates the k8s service from a template.

        Raises:
            _RunnerError: The manifest didn't create a V1Service named
                kwargs['service_name'].
        """
        service = self._create_from_template(template, **kwargs)
        if not isinstance(service, k8s.V1Service):
            raise _RunnerError('Expected V1Service to be created '
                               f'from manifest {template}')
        if service.metadata.name != kwargs['service_name']:
            raise _RunnerError('V1Service created with unexpected name: '
                               f'{service.metadata.name}')
        logger.debug('V1Service %s created at %s', service.metadata.self_link,
                     service.metadata.creation_timestamp)
        return service

    def _delete_deployment(self, name, wait_for_deletion=True):
        """Stops pod dependencies and deletes the deployment.

        Best-effort: retry failures and NotFound are logged, not raised.
        """
        self.stop_pod_dependencies()
        logger.info('Deleting deployment %s', name)
        try:
            self.k8s_namespace.delete_deployment(name)
        except (retryers.RetryError, k8s.NotFound) as e:
            logger.info('Deployment %s deletion failed: %s', name, e)
            return

        if wait_for_deletion:
            self.k8s_namespace.wait_for_deployment_deleted(name)
        logger.debug('Deployment %s deleted', name)

    def _delete_service(self, name, wait_for_deletion=True):
        """Deletes the service.

        Best-effort: retry failures and NotFound are logged, not raised.
        """
        logger.info('Deleting service %s', name)
        try:
            self.k8s_namespace.delete_service(name)
        except (retryers.RetryError, k8s.NotFound) as e:
            logger.info('Service %s deletion failed: %s', name, e)
            return

        if wait_for_deletion:
            self.k8s_namespace.wait_for_service_deleted(name)

        logger.debug('Service %s deleted', name)

    def _delete_service_account(self, name, wait_for_deletion=True):
        """Deletes the k8s service account.

        Best-effort: retry failures and NotFound are logged, not raised.
        """
        logger.info('Deleting service account %s', name)
        try:
            self.k8s_namespace.delete_service_account(name)
        except (retryers.RetryError, k8s.NotFound) as e:
            logger.info('Service account %s deletion failed: %s', name, e)
            return

        if wait_for_deletion:
            self.k8s_namespace.wait_for_service_account_deleted(name)
        logger.debug('Service account %s deleted', name)

    def delete_namespace(self, wait_for_deletion=True):
        """Deletes the namespace managed by k8s_namespace.

        Best-effort: retry failures and NotFound are logged, not raised.
        """
        logger.info('Deleting namespace %s', self.k8s_namespace.name)
        try:
            self.k8s_namespace.delete()
        except (retryers.RetryError, k8s.NotFound) as e:
            logger.info('Namespace %s deletion failed: %s',
                        self.k8s_namespace.name, e)
            return

        if wait_for_deletion:
            self.k8s_namespace.wait_for_namespace_deleted()
        logger.debug('Namespace %s deleted', self.k8s_namespace.name)

    def _wait_deployment_with_available_replicas(self, name, count=1, **kwargs):
        """Waits until the deployment reports `count` available replicas.

        Extra kwargs are forwarded to the namespace wait helper.
        """
        logger.info(
            'Waiting for deployment %s to report %s '
            'available replica(s)', name, count)
        self.k8s_namespace.wait_for_deployment_available_replicas(
            name, count, **kwargs)
        deployment = self.k8s_namespace.get_deployment(name)
        logger.info('Deployment %s has %i replicas available',
                    deployment.metadata.name,
                    deployment.status.available_replicas)

    def _wait_deployment_pod_count(self,
                                   deployment: k8s.V1Deployment,
                                   count: int = 1,
                                   **kwargs) -> List[str]:
        """Waits until the deployment has `count` pods; returns pod names.

        Extra kwargs are forwarded to the namespace wait helper.
        """
        logger.info('Waiting for deployment %s to initialize %s pod(s)',
                    deployment.metadata.name, count)
        self.k8s_namespace.wait_for_deployment_replica_count(
            deployment, count, **kwargs)
        pods = self.k8s_namespace.list_deployment_pods(deployment)
        pod_names = [pod.metadata.name for pod in pods]
        # Log the number of pods actually listed, which is what's returned.
        logger.info('Deployment %s initialized %i pod(s): %s',
                    deployment.metadata.name, len(pod_names), pod_names)
        # Pods may not be started yet, just return the names.
        return pod_names

    def _wait_pod_started(self, name, **kwargs) -> k8s.V1Pod:
        """Waits for the pod to start and returns its refreshed object."""
        logger.info('Waiting for pod %s to start', name)
        self.k8s_namespace.wait_for_pod_started(name, **kwargs)
        pod = self.k8s_namespace.get_pod(name)
        logger.info('Pod %s ready, IP: %s', pod.metadata.name,
                    pod.status.pod_ip)
        return pod

    def _start_port_forwarding_pod(self, pod: k8s.V1Pod,
                                   remote_port: int) -> k8s.PortForwarder:
        """Starts port forwarding to the pod's remote_port.

        The forwarder is tracked and closed by stop_pod_dependencies().
        """
        logger.info('LOCAL DEV MODE: Enabling port forwarding to %s:%s',
                    pod.status.pod_ip, remote_port)
        port_forwarder = self.k8s_namespace.port_forward_pod(pod, remote_port)
        self.pod_port_forwarders.append(port_forwarder)
        return port_forwarder

    def _start_logging_pod(self,
                           pod: k8s.V1Pod,
                           *,
                           log_to_stdout: bool = False) -> k8s.PodLogCollector:
        """Starts collecting the pod's logs into a file under logs_subdir.

        The collector is tracked and flushed by stop_pod_dependencies().
        """
        pod_name = pod.metadata.name
        logfile_name = f'{self.k8s_namespace.name}_{pod_name}.log'
        log_path = self.logs_subdir / logfile_name
        logger.info('Enabling log collection from pod %s to %s', pod_name,
                    log_path.relative_to(self.logs_subdir.parent.parent))
        pod_log_collector = self.k8s_namespace.pod_start_logging(
            pod_name=pod_name,
            log_path=log_path,
            log_stop_event=self.log_stop_event,
            log_to_stdout=log_to_stdout,
            # Timestamps are enabled because not all language implementations
            # include them.
            # TODO(sergiitk): Make this setting language-specific.
            log_timestamps=True)
        self.pod_log_collectors.append(pod_log_collector)
        return pod_log_collector

    def _wait_service_neg(self, name, service_port, **kwargs):
        """Waits for the service's NEG and logs its name and zones."""
        logger.info('Waiting for NEG for service %s', name)
        self.k8s_namespace.wait_for_service_neg(name, **kwargs)
        neg_name, neg_zones = self.k8s_namespace.get_service_neg(
            name, service_port)
        logger.info("Service %s: detected NEG=%s in zones=%s", name, neg_name,
                    neg_zones)

    def logs_explorer_link(self):
        """Prints GCP Logs Explorer link to all runs of the deployment."""
        self._logs_explorer_link(deployment_name=self.deployment_name,
                                 namespace_name=self.k8s_namespace.name,
                                 gcp_project=self.gcp_project,
                                 gcp_ui_url=self.gcp_ui_url)

    def logs_explorer_run_history_links(self):
        """Prints a separate GCP Logs Explorer link for each run *completed* by
        the runner.

        This excludes the current run, if it hasn't been completed.
        """
        if not self.run_history:
            logger.info('No completed deployments of %s', self.deployment_name)
            return
        for run in self.run_history:
            self._logs_explorer_link(deployment_name=self.deployment_name,
                                     namespace_name=self.k8s_namespace.name,
                                     gcp_project=self.gcp_project,
                                     gcp_ui_url=self.gcp_ui_url,
                                     deployment_id=run.deployment_id,
                                     start_time=run.time_start_requested,
                                     end_time=run.time_stopped)

    @classmethod
    def _logs_explorer_link(cls,
                            *,
                            deployment_name: str,
                            namespace_name: str,
                            gcp_project: str,
                            gcp_ui_url: str,
                            deployment_id: Optional[str] = None,
                            start_time: Optional[_datetime] = None,
                            end_time: Optional[_datetime] = None):
        """Output the link to test server/client logs in GCP Logs Explorer.

        start_time defaults to now, end_time to start_time + 30 minutes.
        When deployment_id is given, the query also filters on the pod's
        deployment_id label.
        """
        if not start_time:
            start_time = _datetime.now()
        if not end_time:
            end_time = start_time + _timedelta(minutes=30)

        logs_start = _helper_datetime.iso8601_utc_time(start_time)
        logs_end = _helper_datetime.iso8601_utc_time(end_time)
        request = {'timeRange': f'{logs_start}/{logs_end}'}
        query = {
            'resource.type': 'k8s_container',
            'resource.labels.project_id': gcp_project,
            'resource.labels.container_name': deployment_name,
            'resource.labels.namespace_name': namespace_name,
        }
        if deployment_id:
            query['labels."k8s-pod/deployment_id"'] = deployment_id

        link = cls._logs_explorer_link_from_params(gcp_ui_url=gcp_ui_url,
                                                   gcp_project=gcp_project,
                                                   query=query,
                                                   request=request)
        link_to = deployment_id if deployment_id else deployment_name
        # A whitespace at the end to indicate the end of the url.
        logger.info("GCP Logs Explorer link to %s:\n%s ", link_to, link)

    @classmethod
    def _make_namespace_name(cls, resource_prefix: str, resource_suffix: str,
                             name: str) -> str:
        """A helper to make consistent test app kubernetes namespace name
        for given resource prefix and suffix.

        Produces "<prefix>-<name>-<suffix>", or "<prefix>-<name>" when the
        suffix is empty.
        """
        parts = [resource_prefix, name]
        # Avoid trailing dash when the suffix is empty.
        if resource_suffix:
            parts.append(resource_suffix)
        return '-'.join(parts)
587