1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import abc
15import contextlib
16import datetime
17import enum
18import hashlib
19import logging
20import re
21import signal
22import time
23from types import FrameType
24from typing import Any, Callable, List, Optional, Tuple, Union
25
26from absl import flags
27from absl.testing import absltest
28from google.protobuf import json_format
29import grpc
30
31from framework import xds_flags
32from framework import xds_k8s_flags
33from framework import xds_url_map_testcase
34from framework.helpers import grpc as helpers_grpc
35from framework.helpers import rand as helpers_rand
36from framework.helpers import retryers
37from framework.helpers import skips
38from framework.infrastructure import gcp
39from framework.infrastructure import k8s
40from framework.infrastructure import traffic_director
41from framework.rpc import grpc_channelz
42from framework.rpc import grpc_csds
43from framework.rpc import grpc_testing
44from framework.test_app import client_app
45from framework.test_app import server_app
46from framework.test_app.runners.k8s import k8s_xds_client_runner
47from framework.test_app.runners.k8s import k8s_xds_server_runner
48
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# TODO(yashkt): We will no longer need this flag once Core exposes local certs
# from channelz
_CHECK_LOCAL_CERTS = flags.DEFINE_bool(
    "check_local_certs",
    default=True,
    help="Security Tests also check the value of local certs")
# Declare the key flags of the shared flag modules as key flags of this
# module as well.
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)
58
# Type aliases: short local names for the framework types used below.
TrafficDirectorManager = traffic_director.TrafficDirectorManager
TrafficDirectorAppNetManager = traffic_director.TrafficDirectorAppNetManager
TrafficDirectorSecureManager = traffic_director.TrafficDirectorSecureManager
XdsTestServer = server_app.XdsTestServer
XdsTestClient = client_app.XdsTestClient
KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner
KubernetesClientRunner = k8s_xds_client_runner.KubernetesClientRunner
LoadBalancerStatsResponse = grpc_testing.LoadBalancerStatsResponse
_ChannelState = grpc_channelz.ChannelState
_timedelta = datetime.timedelta
ClientConfig = grpc_csds.ClientConfig
# pylint complains about signal.Signals for some reason.
_SignalNum = Union[int, signal.Signals]  # pylint: disable=no-member
# Callable taking (signal number, interrupted frame), as used for handlers
# registered with signal.signal().
_SignalHandler = Callable[[_SignalNum, Optional[FrameType]], Any]

# Retry timeout, in seconds, used while waiting for RPCs to reach the
# expected servers.
_TD_CONFIG_MAX_WAIT_SEC = 600
76
77
class TdPropagationRetryableError(Exception):
    """Raised while the TD config has not propagated yet.

    The condition is transient, so the operation is safe to retry.
    """
80
81
class XdsKubernetesBaseTestCase(absltest.TestCase):
    """Base test case for xDS tests running on Kubernetes.

    setUpClass() loads the configuration from flags and creates the API
    managers, setUp() installs a SIGINT handler, and the remaining methods
    are shared setup helpers and RPC / xDS config assertions.
    """
    lang_spec: skips.TestConfig
    client_namespace: str
    client_runner: KubernetesClientRunner
    ensure_firewall: bool
    force_cleanup: bool
    gcp_api_manager: gcp.api.GcpApiManager
    gcp_service_account: Optional[str]
    k8s_api_manager: k8s.KubernetesApiManager
    secondary_k8s_api_manager: k8s.KubernetesApiManager
    network: str
    project: str
    resource_prefix: str
    resource_suffix: str = ''
    # Whether to randomize resources names for each test by appending a
    # unique suffix.
    resource_suffix_randomize: bool = True
    server_maintenance_port: Optional[int]
    server_namespace: str
    server_runner: KubernetesServerRunner
    server_xds_host: str
    server_xds_port: int
    td: TrafficDirectorManager
    td_bootstrap_image: str
    # SIGINT handler that was active before setUp() installed handle_sigint();
    # reinstalled by handle_sigint().
    _prev_sigint_handler: Optional[_SignalHandler] = None
    # True while handle_sigint() is running the teardown; subTest() checks it
    # to skip its "Finished" log line.
    _handling_sigint: bool = False
108
109    @staticmethod
110    def is_supported(config: skips.TestConfig) -> bool:
111        """Overridden by the test class to decide if the config is supported.
112
113        Returns:
114          A bool indicates if the given config is supported.
115        """
116        del config
117        return True
118
    @classmethod
    def setUpClass(cls):
        """Hook method for setting up class fixture before running tests in
        the class.

        Evaluates whether the test class is supported, copies the flag values
        onto class attributes, and creates the Kubernetes and GCP API
        managers.
        """
        logger.info('----- Testing %s -----', cls.__name__)
        logger.info('Logs timezone: %s', time.localtime().tm_zone)

        # Raises unittest.SkipTest if given client/server/version does not
        # support current test case.
        cls.lang_spec = skips.evaluate_test_config(cls.is_supported)

        # Must be called before KubernetesApiManager or GcpApiManager init.
        xds_flags.set_socket_default_timeout_from_flag()

        # GCP
        cls.project = xds_flags.PROJECT.value
        cls.network = xds_flags.NETWORK.value
        cls.gcp_service_account = xds_k8s_flags.GCP_SERVICE_ACCOUNT.value
        cls.td_bootstrap_image = xds_k8s_flags.TD_BOOTSTRAP_IMAGE.value
        cls.xds_server_uri = xds_flags.XDS_SERVER_URI.value
        cls.ensure_firewall = xds_flags.ENSURE_FIREWALL.value
        cls.firewall_allowed_ports = xds_flags.FIREWALL_ALLOWED_PORTS.value
        cls.compute_api_version = xds_flags.COMPUTE_API_VERSION.value

        # Resource names.
        cls.resource_prefix = xds_flags.RESOURCE_PREFIX.value
        # An explicitly provided suffix disables per-test randomization.
        if xds_flags.RESOURCE_SUFFIX.value is not None:
            cls.resource_suffix_randomize = False
            cls.resource_suffix = xds_flags.RESOURCE_SUFFIX.value

        # Test server
        cls.server_image = xds_k8s_flags.SERVER_IMAGE.value
        cls.server_name = xds_flags.SERVER_NAME.value
        cls.server_port = xds_flags.SERVER_PORT.value
        cls.server_maintenance_port = xds_flags.SERVER_MAINTENANCE_PORT.value
        # The xDS host is taken from the same flag as the server name.
        cls.server_xds_host = xds_flags.SERVER_NAME.value
        cls.server_xds_port = xds_flags.SERVER_XDS_PORT.value

        # Test client
        cls.client_image = xds_k8s_flags.CLIENT_IMAGE.value
        cls.client_name = xds_flags.CLIENT_NAME.value
        cls.client_port = xds_flags.CLIENT_PORT.value

        # Test suite settings
        cls.force_cleanup = xds_flags.FORCE_CLEANUP.value
        cls.debug_use_port_forwarding = \
            xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value
        cls.enable_workload_identity = \
            xds_k8s_flags.ENABLE_WORKLOAD_IDENTITY.value
        cls.check_local_certs = _CHECK_LOCAL_CERTS.value

        # Resource managers
        cls.k8s_api_manager = k8s.KubernetesApiManager(
            xds_k8s_flags.KUBE_CONTEXT.value)
        cls.secondary_k8s_api_manager = k8s.KubernetesApiManager(
            xds_k8s_flags.SECONDARY_KUBE_CONTEXT.value)
        cls.gcp_api_manager = gcp.api.GcpApiManager()
177
178    @classmethod
179    def tearDownClass(cls):
180        cls.k8s_api_manager.close()
181        cls.secondary_k8s_api_manager.close()
182        cls.gcp_api_manager.close()
183
184    def setUp(self):
185        self._prev_sigint_handler = signal.signal(signal.SIGINT,
186                                                  self.handle_sigint)
187
    def handle_sigint(self, signalnum: _SignalNum,
                      frame: Optional[FrameType]) -> None:
        """SIGINT handler: force the cleanup, then abort the test run.

        Runs tearDown() and tearDownClass() with force_cleanup enabled,
        reinstalls the SIGINT handler saved by setUp() (when there is one),
        and raises KeyboardInterrupt.

        Args:
          signalnum: The received signal number. Not used.
          frame: The interrupted stack frame. Not used.

        Raises:
          KeyboardInterrupt: After the teardown calls have returned.
        """
        logger.info('Caught Ctrl+C, cleaning up...')
        self._handling_sigint = True
        # Force resource cleanup by their name. Addresses the case where ctrl-c
        # is pressed while waiting for the resource creation.
        self.force_cleanup = True
        # NOTE(review): if either teardown call raises, the lines after it are
        # skipped: the flag stays set and the previous handler is not restored.
        self.tearDown()
        self.tearDownClass()
        self._handling_sigint = False
        if self._prev_sigint_handler is not None:
            signal.signal(signal.SIGINT, self._prev_sigint_handler)
        raise KeyboardInterrupt
201
    @contextlib.contextmanager
    def subTest(self, msg, **params):  # noqa pylint: disable=signature-differs
        """Context manager that logs the start and the end of a subtest.

        Args:
          msg: Subtest name; logged after the test id.
          **params: Passed through to the parent class' subTest().
        """
        logger.info('--- Starting subTest %s.%s ---', self.id(), msg)
        try:
            # NOTE(review): the context manager returned by the parent's
            # subTest() is yielded to the caller but never entered here, so
            # an exception in the with-body propagates out of this context
            # manager as-is. Confirm this is intentional.
            yield super().subTest(msg, **params)
        finally:
            # Skip the log line when unwinding because of the SIGINT cleanup.
            if not self._handling_sigint:
                logger.info('--- Finished subTest %s.%s ---', self.id(), msg)
210
211    def setupTrafficDirectorGrpc(self):
212        self.td.setup_for_grpc(self.server_xds_host,
213                               self.server_xds_port,
214                               health_check_port=self.server_maintenance_port)
215
216    def setupServerBackends(self,
217                            *,
218                            wait_for_healthy_status=True,
219                            server_runner=None,
220                            max_rate_per_endpoint: Optional[int] = None):
221        if server_runner is None:
222            server_runner = self.server_runner
223        # Load Backends
224        neg_name, neg_zones = server_runner.k8s_namespace.get_service_neg(
225            server_runner.service_name, self.server_port)
226
227        # Add backends to the Backend Service
228        self.td.backend_service_add_neg_backends(
229            neg_name, neg_zones, max_rate_per_endpoint=max_rate_per_endpoint)
230        if wait_for_healthy_status:
231            self.td.wait_for_backends_healthy_status()
232
233    def removeServerBackends(self, *, server_runner=None):
234        if server_runner is None:
235            server_runner = self.server_runner
236        # Load Backends
237        neg_name, neg_zones = server_runner.k8s_namespace.get_service_neg(
238            server_runner.service_name, self.server_port)
239
240        # Remove backends from the Backend Service
241        self.td.backend_service_remove_neg_backends(neg_name, neg_zones)
242
243    def assertSuccessfulRpcs(self,
244                             test_client: XdsTestClient,
245                             num_rpcs: int = 100):
246        lb_stats = self.getClientRpcStats(test_client, num_rpcs)
247        self.assertAllBackendsReceivedRpcs(lb_stats)
248        failed = int(lb_stats.num_failures)
249        self.assertLessEqual(
250            failed,
251            0,
252            msg=f'Expected all RPCs to succeed: {failed} of {num_rpcs} failed')
253
254    @staticmethod
255    def diffAccumulatedStatsPerMethod(
256        before: grpc_testing.LoadBalancerAccumulatedStatsResponse,
257        after: grpc_testing.LoadBalancerAccumulatedStatsResponse
258    ) -> grpc_testing.LoadBalancerAccumulatedStatsResponse:
259        """Only diffs stats_per_method, as the other fields are deprecated."""
260        diff = grpc_testing.LoadBalancerAccumulatedStatsResponse()
261        for method, method_stats in after.stats_per_method.items():
262            for status, count in method_stats.result.items():
263                count -= before.stats_per_method[method].result[status]
264                if count < 0:
265                    raise AssertionError("Diff of count shouldn't be negative")
266                if count > 0:
267                    diff.stats_per_method[method].result[status] = count
268        return diff
269
270    def assertRpcStatusCodes(self,
271                             test_client: XdsTestClient,
272                             *,
273                             expected_status: grpc.StatusCode,
274                             duration: _timedelta,
275                             method: str,
276                             stray_rpc_limit: int = 0) -> None:
277        """Assert all RPCs for a method are completing with a certain status."""
278        # Sending with pre-set QPS for a period of time
279        before_stats = test_client.get_load_balancer_accumulated_stats()
280        response_type = 'LoadBalancerAccumulatedStatsResponse'
281        logging.info('Received %s from test client %s: before:\n%s',
282                     response_type, test_client.hostname, before_stats)
283        time.sleep(duration.total_seconds())
284        after_stats = test_client.get_load_balancer_accumulated_stats()
285        logging.info('Received %s from test client %s: after:\n%s',
286                     response_type, test_client.hostname, after_stats)
287
288        diff_stats = self.diffAccumulatedStatsPerMethod(before_stats,
289                                                        after_stats)
290        stats = diff_stats.stats_per_method[method]
291        for found_status_int, count in stats.result.items():
292            found_status = helpers_grpc.status_from_int(found_status_int)
293            if found_status != expected_status and count > stray_rpc_limit:
294                self.fail(f"Expected only status"
295                          f" {helpers_grpc.status_pretty(expected_status)},"
296                          " but found status"
297                          f" {helpers_grpc.status_pretty(found_status)}"
298                          f" for method {method}:\n{diff_stats}")
299
300        expected_status_int: int = expected_status.value[0]
301        self.assertGreater(
302            stats.result[expected_status_int],
303            0,
304            msg=("Expected non-zero RPCs with status"
305                 f" {helpers_grpc.status_pretty(expected_status)}"
306                 f" for method {method}, got:\n{diff_stats}"))
307
308    def assertRpcsEventuallyGoToGivenServers(self,
309                                             test_client: XdsTestClient,
310                                             servers: List[XdsTestServer],
311                                             num_rpcs: int = 100):
312        retryer = retryers.constant_retryer(
313            wait_fixed=datetime.timedelta(seconds=1),
314            timeout=datetime.timedelta(seconds=_TD_CONFIG_MAX_WAIT_SEC),
315            log_level=logging.INFO)
316        try:
317            retryer(self._assertRpcsEventuallyGoToGivenServers, test_client,
318                    servers, num_rpcs)
319        except retryers.RetryError as retry_error:
320            logger.exception(
321                'Rpcs did not go to expected servers before timeout %s',
322                _TD_CONFIG_MAX_WAIT_SEC)
323            raise retry_error
324
325    def _assertRpcsEventuallyGoToGivenServers(self, test_client: XdsTestClient,
326                                              servers: List[XdsTestServer],
327                                              num_rpcs: int):
328        server_hostnames = [server.hostname for server in servers]
329        logger.info('Verifying RPCs go to servers %s', server_hostnames)
330        lb_stats = self.getClientRpcStats(test_client, num_rpcs)
331        failed = int(lb_stats.num_failures)
332        self.assertLessEqual(
333            failed,
334            0,
335            msg=f'Expected all RPCs to succeed: {failed} of {num_rpcs} failed')
336        for server_hostname in server_hostnames:
337            self.assertIn(server_hostname, lb_stats.rpcs_by_peer,
338                          f'Server {server_hostname} did not receive RPCs')
339        for server_hostname in lb_stats.rpcs_by_peer.keys():
340            self.assertIn(server_hostname, server_hostnames,
341                          f'Unexpected server {server_hostname} received RPCs')
342
343    def assertXdsConfigExists(self, test_client: XdsTestClient):
344        config = test_client.csds.fetch_client_status(log_level=logging.INFO)
345        self.assertIsNotNone(config)
346        seen = set()
347        want = frozenset([
348            'listener_config',
349            'cluster_config',
350            'route_config',
351            'endpoint_config',
352        ])
353        for xds_config in config.xds_config:
354            seen.add(xds_config.WhichOneof('per_xds_config'))
355        for generic_xds_config in config.generic_xds_configs:
356            if re.search(r'\.Listener$', generic_xds_config.type_url):
357                seen.add('listener_config')
358            elif re.search(r'\.RouteConfiguration$',
359                           generic_xds_config.type_url):
360                seen.add('route_config')
361            elif re.search(r'\.Cluster$', generic_xds_config.type_url):
362                seen.add('cluster_config')
363            elif re.search(r'\.ClusterLoadAssignment$',
364                           generic_xds_config.type_url):
365                seen.add('endpoint_config')
366        logger.debug('Received xDS config dump: %s',
367                     json_format.MessageToJson(config, indent=2))
368        self.assertSameElements(want, seen)
369
370    def assertRouteConfigUpdateTrafficHandoff(
371            self, test_client: XdsTestClient,
372            previous_route_config_version: str, retry_wait_second: int,
373            timeout_second: int):
374        retryer = retryers.constant_retryer(
375            wait_fixed=datetime.timedelta(seconds=retry_wait_second),
376            timeout=datetime.timedelta(seconds=timeout_second),
377            retry_on_exceptions=(TdPropagationRetryableError,),
378            logger=logger,
379            log_level=logging.INFO)
380        try:
381            for attempt in retryer:
382                with attempt:
383                    self.assertSuccessfulRpcs(test_client)
384                    raw_config = test_client.csds.fetch_client_status(
385                        log_level=logging.INFO)
386                    dumped_config = xds_url_map_testcase.DumpedXdsConfig(
387                        json_format.MessageToDict(raw_config))
388                    route_config_version = dumped_config.rds_version
389                    if previous_route_config_version == route_config_version:
390                        logger.info(
391                            'Routing config not propagated yet. Retrying.')
392                        raise TdPropagationRetryableError(
393                            "CSDS not get updated routing config corresponding"
394                            " to the second set of url maps")
395                    else:
396                        self.assertSuccessfulRpcs(test_client)
397                        logger.info(
398                            ('[SUCCESS] Confirmed successful RPC with the '
399                             'updated routing config, version=%s'),
400                            route_config_version)
401        except retryers.RetryError as retry_error:
402            logger.info(
403                ('Retry exhausted. TD routing config propagation failed after '
404                 'timeout %ds. Last seen client config dump: %s'),
405                timeout_second, dumped_config)
406            raise retry_error
407
408    def assertFailedRpcs(self,
409                         test_client: XdsTestClient,
410                         num_rpcs: Optional[int] = 100):
411        lb_stats = self.getClientRpcStats(test_client, num_rpcs)
412        failed = int(lb_stats.num_failures)
413        self.assertEqual(
414            failed,
415            num_rpcs,
416            msg=f'Expected all RPCs to fail: {failed} of {num_rpcs} failed')
417
418    @staticmethod
419    def getClientRpcStats(test_client: XdsTestClient,
420                          num_rpcs: int) -> LoadBalancerStatsResponse:
421        lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs)
422        logger.info(
423            'Received LoadBalancerStatsResponse from test client %s:\n%s',
424            test_client.hostname, lb_stats)
425        return lb_stats
426
427    def assertAllBackendsReceivedRpcs(self, lb_stats):
428        # TODO(sergiitk): assert backends length
429        for backend, rpcs_count in lb_stats.rpcs_by_peer.items():
430            self.assertGreater(
431                int(rpcs_count),
432                0,
433                msg=f'Backend {backend} did not receive a single RPC')
434
435
class IsolatedXdsKubernetesTestCase(XdsKubernetesBaseTestCase,
                                    metaclass=abc.ABCMeta):
    """Isolated test case.

    Base class for test cases where infra resources are created before
    each test, and destroyed after.
    """
443
444    def setUp(self):
445        """Hook method for setting up the test fixture before exercising it."""
446        super().setUp()
447
448        if self.resource_suffix_randomize:
449            self.resource_suffix = helpers_rand.random_resource_suffix()
450        logger.info('Test run resource prefix: %s, suffix: %s',
451                    self.resource_prefix, self.resource_suffix)
452
453        # TD Manager
454        self.td = self.initTrafficDirectorManager()
455
456        # Test Server runner
457        self.server_namespace = KubernetesServerRunner.make_namespace_name(
458            self.resource_prefix, self.resource_suffix)
459        self.server_runner = self.initKubernetesServerRunner()
460
461        # Test Client runner
462        self.client_namespace = KubernetesClientRunner.make_namespace_name(
463            self.resource_prefix, self.resource_suffix)
464        self.client_runner = self.initKubernetesClientRunner()
465
466        # Ensures the firewall exist
467        if self.ensure_firewall:
468            self.td.create_firewall_rule(
469                allowed_ports=self.firewall_allowed_ports)
470
471        # Randomize xds port, when it's set to 0
472        if self.server_xds_port == 0:
473            # TODO(sergiitk): this is prone to race conditions:
474            #  The port might not me taken now, but there's not guarantee
475            #  it won't be taken until the tests get to creating
476            #  forwarding rule. This check is better than nothing,
477            #  but we should find a better approach.
478            self.server_xds_port = self.td.find_unused_forwarding_rule_port()
479            logger.info('Found unused xds port: %s', self.server_xds_port)
480
481    @abc.abstractmethod
482    def initTrafficDirectorManager(self) -> TrafficDirectorManager:
483        raise NotImplementedError
484
485    @abc.abstractmethod
486    def initKubernetesServerRunner(self) -> KubernetesServerRunner:
487        raise NotImplementedError
488
489    @abc.abstractmethod
490    def initKubernetesClientRunner(self) -> KubernetesClientRunner:
491        raise NotImplementedError
492
493    def tearDown(self):
494        logger.info('----- TestMethod %s teardown -----', self.id())
495        logger.debug('Getting pods restart times')
496        client_restarts: int = 0
497        server_restarts: int = 0
498        try:
499            client_restarts = self.client_runner.get_pod_restarts(
500                self.client_runner.deployment)
501            server_restarts = self.server_runner.get_pod_restarts(
502                self.server_runner.deployment)
503        except (retryers.RetryError, k8s.NotFound) as e:
504            logger.exception(e)
505
506        retryer = retryers.constant_retryer(wait_fixed=_timedelta(seconds=10),
507                                            attempts=3,
508                                            log_level=logging.INFO)
509        try:
510            retryer(self.cleanup)
511        except retryers.RetryError:
512            logger.exception('Got error during teardown')
513        finally:
514            logger.info('----- Test client/server logs -----')
515            self.client_runner.logs_explorer_run_history_links()
516            self.server_runner.logs_explorer_run_history_links()
517
518            # Fail if any of the pods restarted.
519            self.assertEqual(
520                client_restarts,
521                0,
522                msg=
523                ('Client pods unexpectedly restarted'
524                 f' {client_restarts} times during test.'
525                 ' In most cases, this is caused by the test client app crash.'
526                ))
527            self.assertEqual(
528                server_restarts,
529                0,
530                msg=
531                ('Server pods unexpectedly restarted'
532                 f' {server_restarts} times during test.'
533                 ' In most cases, this is caused by the test client app crash.'
534                ))
535
536    def cleanup(self):
537        self.td.cleanup(force=self.force_cleanup)
538        self.client_runner.cleanup(force=self.force_cleanup)
539        self.server_runner.cleanup(force=self.force_cleanup,
540                                   force_namespace=self.force_cleanup)
541
542
class RegularXdsKubernetesTestCase(IsolatedXdsKubernetesTestCase):
    """Regular test case base class for testing PSM features in isolation."""

    @classmethod
    def setUpClass(cls):
        """Set up the class fixture and default the maintenance port.

        When no maintenance port was configured, the server runner's
        default maintenance port is used.
        """
        super().setUpClass()
        if cls.server_maintenance_port is not None:
            return
        cls.server_maintenance_port = (
            KubernetesServerRunner.DEFAULT_MAINTENANCE_PORT)

    def initTrafficDirectorManager(self) -> TrafficDirectorManager:
        """Create the TrafficDirectorManager for this test."""
        manager_kwargs = dict(project=self.project,
                              resource_prefix=self.resource_prefix,
                              resource_suffix=self.resource_suffix,
                              network=self.network,
                              compute_api_version=self.compute_api_version)
        return TrafficDirectorManager(self.gcp_api_manager, **manager_kwargs)

    def initKubernetesServerRunner(self) -> KubernetesServerRunner:
        """Create the server runner in the server namespace."""
        namespace = k8s.KubernetesNamespace(self.k8s_api_manager,
                                            self.server_namespace)
        return KubernetesServerRunner(
            namespace,
            deployment_name=self.server_name,
            image_name=self.server_image,
            td_bootstrap_image=self.td_bootstrap_image,
            gcp_project=self.project,
            gcp_api_manager=self.gcp_api_manager,
            gcp_service_account=self.gcp_service_account,
            xds_server_uri=self.xds_server_uri,
            network=self.network,
            debug_use_port_forwarding=self.debug_use_port_forwarding,
            enable_workload_identity=self.enable_workload_identity)

    def initKubernetesClientRunner(self) -> KubernetesClientRunner:
        """Create the client runner in the client namespace.

        The namespace is marked as reused when the client and the server
        namespaces have the same name.
        """
        namespace = k8s.KubernetesNamespace(self.k8s_api_manager,
                                            self.client_namespace)
        shares_namespace = self.server_namespace == self.client_namespace
        return KubernetesClientRunner(
            namespace,
            deployment_name=self.client_name,
            image_name=self.client_image,
            td_bootstrap_image=self.td_bootstrap_image,
            gcp_project=self.project,
            gcp_api_manager=self.gcp_api_manager,
            gcp_service_account=self.gcp_service_account,
            xds_server_uri=self.xds_server_uri,
            network=self.network,
            debug_use_port_forwarding=self.debug_use_port_forwarding,
            enable_workload_identity=self.enable_workload_identity,
            stats_port=self.client_port,
            reuse_namespace=shares_namespace)

    def startTestServers(self,
                         replica_count=1,
                         server_runner=None,
                         **kwargs) -> List[XdsTestServer]:
        """Run the test servers and assign them the xDS address.

        Args:
          replica_count: Number of server replicas to run.
          server_runner: Runner to use; defaults to self.server_runner.
          **kwargs: Passed through to the runner's run().

        Returns:
          The started test servers.
        """
        runner = self.server_runner if server_runner is None else server_runner
        started_servers = runner.run(
            replica_count=replica_count,
            test_port=self.server_port,
            maintenance_port=self.server_maintenance_port,
            **kwargs)
        for started_server in started_servers:
            started_server.set_xds_address(self.server_xds_host,
                                           self.server_xds_port)
        return started_servers

    def startTestClient(self, test_server: XdsTestServer,
                        **kwargs) -> XdsTestClient:
        """Run the test client against the server's xDS URI.

        Waits for an active server channel before returning.

        Args:
          test_server: The server whose `xds_uri` is the client target.
          **kwargs: Passed through to the client runner's run().

        Returns:
          The started test client.
        """
        target = test_server.xds_uri
        started_client = self.client_runner.run(server_target=target, **kwargs)
        started_client.wait_for_active_server_channel()
        return started_client
619
620
class AppNetXdsKubernetesTestCase(RegularXdsKubernetesTestCase):
    """Regular test case that uses TrafficDirectorAppNetManager."""
    td: TrafficDirectorAppNetManager

    def initTrafficDirectorManager(self) -> TrafficDirectorAppNetManager:
        """Create the TrafficDirectorAppNetManager for this test."""
        manager_kwargs = dict(project=self.project,
                              resource_prefix=self.resource_prefix,
                              resource_suffix=self.resource_suffix,
                              network=self.network,
                              compute_api_version=self.compute_api_version)
        return TrafficDirectorAppNetManager(self.gcp_api_manager,
                                            **manager_kwargs)
632
633
class SecurityXdsKubernetesTestCase(IsolatedXdsKubernetesTestCase):
    """Test case base class for testing PSM security features in isolation."""
    # Narrows the type of `td` to the manager returned by
    # initTrafficDirectorManager() in this class.
    td: TrafficDirectorSecureManager

    class SecurityMode(enum.Enum):
        """Security mode, accepted as `mode` by assertTestAppSecurity()."""
        MTLS = enum.auto()
        TLS = enum.auto()
        PLAINTEXT = enum.auto()
642
643    @classmethod
644    def setUpClass(cls):
645        """Hook method for setting up class fixture before running tests in
646        the class.
647        """
648        super().setUpClass()
649        if cls.server_maintenance_port is None:
650            # In secure mode, the maintenance port is different from
651            # the test port to keep it insecure, and make
652            # Health Checks and Channelz tests available.
653            # When not provided, use explicit numeric port value, so
654            # Backend Health Checks are created on a fixed port.
655            cls.server_maintenance_port = \
656                KubernetesServerRunner.DEFAULT_SECURE_MODE_MAINTENANCE_PORT
657
658    def initTrafficDirectorManager(self) -> TrafficDirectorSecureManager:
659        return TrafficDirectorSecureManager(
660            self.gcp_api_manager,
661            project=self.project,
662            resource_prefix=self.resource_prefix,
663            resource_suffix=self.resource_suffix,
664            network=self.network,
665            compute_api_version=self.compute_api_version)
666
667    def initKubernetesServerRunner(self) -> KubernetesServerRunner:
668        return KubernetesServerRunner(
669            k8s.KubernetesNamespace(self.k8s_api_manager,
670                                    self.server_namespace),
671            deployment_name=self.server_name,
672            image_name=self.server_image,
673            td_bootstrap_image=self.td_bootstrap_image,
674            gcp_project=self.project,
675            gcp_api_manager=self.gcp_api_manager,
676            gcp_service_account=self.gcp_service_account,
677            network=self.network,
678            xds_server_uri=self.xds_server_uri,
679            deployment_template='server-secure.deployment.yaml',
680            debug_use_port_forwarding=self.debug_use_port_forwarding)
681
682    def initKubernetesClientRunner(self) -> KubernetesClientRunner:
683        return KubernetesClientRunner(
684            k8s.KubernetesNamespace(self.k8s_api_manager,
685                                    self.client_namespace),
686            deployment_name=self.client_name,
687            image_name=self.client_image,
688            td_bootstrap_image=self.td_bootstrap_image,
689            gcp_project=self.project,
690            gcp_api_manager=self.gcp_api_manager,
691            gcp_service_account=self.gcp_service_account,
692            xds_server_uri=self.xds_server_uri,
693            network=self.network,
694            deployment_template='client-secure.deployment.yaml',
695            stats_port=self.client_port,
696            reuse_namespace=self.server_namespace == self.client_namespace,
697            debug_use_port_forwarding=self.debug_use_port_forwarding)
698
699    def startSecureTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
700        test_server = self.server_runner.run(
701            replica_count=replica_count,
702            test_port=self.server_port,
703            maintenance_port=self.server_maintenance_port,
704            secure_mode=True,
705            **kwargs)[0]
706        test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
707        return test_server
708
709    def setupSecurityPolicies(self, *, server_tls, server_mtls, client_tls,
710                              client_mtls):
711        self.td.setup_client_security(server_namespace=self.server_namespace,
712                                      server_name=self.server_name,
713                                      tls=client_tls,
714                                      mtls=client_mtls)
715        self.td.setup_server_security(server_namespace=self.server_namespace,
716                                      server_name=self.server_name,
717                                      server_port=self.server_port,
718                                      tls=server_tls,
719                                      mtls=server_mtls)
720
721    def startSecureTestClient(self,
722                              test_server: XdsTestServer,
723                              *,
724                              wait_for_active_server_channel=True,
725                              **kwargs) -> XdsTestClient:
726        test_client = self.client_runner.run(server_target=test_server.xds_uri,
727                                             secure_mode=True,
728                                             **kwargs)
729        if wait_for_active_server_channel:
730            test_client.wait_for_active_server_channel()
731        return test_client
732
733    def assertTestAppSecurity(self, mode: SecurityMode,
734                              test_client: XdsTestClient,
735                              test_server: XdsTestServer):
736        client_socket, server_socket = self.getConnectedSockets(
737            test_client, test_server)
738        server_security: grpc_channelz.Security = server_socket.security
739        client_security: grpc_channelz.Security = client_socket.security
740        logger.info('Server certs: %s', self.debug_sock_certs(server_security))
741        logger.info('Client certs: %s', self.debug_sock_certs(client_security))
742
743        if mode is self.SecurityMode.MTLS:
744            self.assertSecurityMtls(client_security, server_security)
745        elif mode is self.SecurityMode.TLS:
746            self.assertSecurityTls(client_security, server_security)
747        elif mode is self.SecurityMode.PLAINTEXT:
748            self.assertSecurityPlaintext(client_security, server_security)
749        else:
750            raise TypeError('Incorrect security mode')
751
752    def assertSecurityMtls(self, client_security: grpc_channelz.Security,
753                           server_security: grpc_channelz.Security):
754        self.assertEqual(client_security.WhichOneof('model'),
755                         'tls',
756                         msg='(mTLS) Client socket security model must be TLS')
757        self.assertEqual(server_security.WhichOneof('model'),
758                         'tls',
759                         msg='(mTLS) Server socket security model must be TLS')
760        server_tls, client_tls = server_security.tls, client_security.tls
761
762        # Confirm regular TLS: server local cert == client remote cert
763        self.assertNotEmpty(client_tls.remote_certificate,
764                            msg="(mTLS) Client remote certificate is missing")
765        if self.check_local_certs:
766            self.assertNotEmpty(
767                server_tls.local_certificate,
768                msg="(mTLS) Server local certificate is missing")
769            self.assertEqual(
770                server_tls.local_certificate,
771                client_tls.remote_certificate,
772                msg="(mTLS) Server local certificate must match client's "
773                "remote certificate")
774
775        # mTLS: server remote cert == client local cert
776        self.assertNotEmpty(server_tls.remote_certificate,
777                            msg="(mTLS) Server remote certificate is missing")
778        if self.check_local_certs:
779            self.assertNotEmpty(
780                client_tls.local_certificate,
781                msg="(mTLS) Client local certificate is missing")
782            self.assertEqual(
783                server_tls.remote_certificate,
784                client_tls.local_certificate,
785                msg="(mTLS) Server remote certificate must match client's "
786                "local certificate")
787
788    def assertSecurityTls(self, client_security: grpc_channelz.Security,
789                          server_security: grpc_channelz.Security):
790        self.assertEqual(client_security.WhichOneof('model'),
791                         'tls',
792                         msg='(TLS) Client socket security model must be TLS')
793        self.assertEqual(server_security.WhichOneof('model'),
794                         'tls',
795                         msg='(TLS) Server socket security model must be TLS')
796        server_tls, client_tls = server_security.tls, client_security.tls
797
798        # Regular TLS: server local cert == client remote cert
799        self.assertNotEmpty(client_tls.remote_certificate,
800                            msg="(TLS) Client remote certificate is missing")
801        if self.check_local_certs:
802            self.assertNotEmpty(server_tls.local_certificate,
803                                msg="(TLS) Server local certificate is missing")
804            self.assertEqual(
805                server_tls.local_certificate,
806                client_tls.remote_certificate,
807                msg="(TLS) Server local certificate must match client "
808                "remote certificate")
809
810        # mTLS must not be used
811        self.assertEmpty(
812            server_tls.remote_certificate,
813            msg="(TLS) Server remote certificate must be empty in TLS mode. "
814            "Is server security incorrectly configured for mTLS?")
815        self.assertEmpty(
816            client_tls.local_certificate,
817            msg="(TLS) Client local certificate must be empty in TLS mode. "
818            "Is client security incorrectly configured for mTLS?")
819
820    def assertSecurityPlaintext(self, client_security, server_security):
821        server_tls, client_tls = server_security.tls, client_security.tls
822        # Not TLS
823        self.assertEmpty(
824            server_tls.local_certificate,
825            msg="(Plaintext) Server local certificate must be empty.")
826        self.assertEmpty(
827            client_tls.local_certificate,
828            msg="(Plaintext) Client local certificate must be empty.")
829
830        # Not mTLS
831        self.assertEmpty(
832            server_tls.remote_certificate,
833            msg="(Plaintext) Server remote certificate must be empty.")
834        self.assertEmpty(
835            client_tls.local_certificate,
836            msg="(Plaintext) Client local certificate must be empty.")
837
838    def assertClientCannotReachServerRepeatedly(
839            self,
840            test_client: XdsTestClient,
841            *,
842            times: Optional[int] = None,
843            delay: Optional[_timedelta] = None):
844        """
845        Asserts that the client repeatedly cannot reach the server.
846
847        With negative tests we can't be absolutely certain expected failure
848        state is not caused by something else.
849        To mitigate for this, we repeat the checks several times, and expect
850        all of them to succeed.
851
852        This is useful in case the channel eventually stabilizes, and RPCs pass.
853
854        Args:
855            test_client: An instance of XdsTestClient
856            times: Optional; A positive number of times to confirm that
857                the server is unreachable. Defaults to `3` attempts.
858            delay: Optional; Specifies how long to wait before the next check.
859                Defaults to `10` seconds.
860        """
861        if times is None or times < 1:
862            times = 3
863        if delay is None:
864            delay = _timedelta(seconds=10)
865
866        for i in range(1, times + 1):
867            self.assertClientCannotReachServer(test_client)
868            if i < times:
869                logger.info('Check %s passed, waiting %s before the next check',
870                            i, delay)
871                time.sleep(delay.total_seconds())
872
873    def assertClientCannotReachServer(self, test_client: XdsTestClient):
874        self.assertClientChannelFailed(test_client)
875        self.assertFailedRpcs(test_client)
876
877    def assertClientChannelFailed(self, test_client: XdsTestClient):
878        channel = test_client.wait_for_server_channel_state(
879            state=_ChannelState.TRANSIENT_FAILURE)
880        subchannels = list(
881            test_client.channelz.list_channel_subchannels(channel))
882        self.assertLen(subchannels,
883                       1,
884                       msg="Client channel must have exactly one subchannel "
885                       "in state TRANSIENT_FAILURE.")
886
887    @staticmethod
888    def getConnectedSockets(
889        test_client: XdsTestClient, test_server: XdsTestServer
890    ) -> Tuple[grpc_channelz.Socket, grpc_channelz.Socket]:
891        client_sock = test_client.get_active_server_channel_socket()
892        server_sock = test_server.get_server_socket_matching_client(client_sock)
893        return client_sock, server_sock
894
895    @classmethod
896    def debug_sock_certs(cls, security: grpc_channelz.Security):
897        if security.WhichOneof('model') == 'other':
898            return f'other: <{security.other.name}={security.other.value}>'
899
900        return (f'local: <{cls.debug_cert(security.tls.local_certificate)}>, '
901                f'remote: <{cls.debug_cert(security.tls.remote_certificate)}>')
902
903    @staticmethod
904    def debug_cert(cert):
905        if not cert:
906            return 'missing'
907        sha1 = hashlib.sha1(cert)
908        return f'sha1={sha1.hexdigest()}, len={len(cert)}'
909