# Copyright 2020 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base test case classes for xDS interop tests running on Kubernetes.

Defines:
  * XdsKubernetesBaseTestCase: reads test flags, owns the API managers and
    provides shared RPC / xDS config assertions.
  * IsolatedXdsKubernetesTestCase: creates TD, server and client resources
    before each test and cleans them up after.
  * RegularXdsKubernetesTestCase / AppNetXdsKubernetesTestCase /
    SecurityXdsKubernetesTestCase: concrete variants wiring up the matching
    Traffic Director manager and Kubernetes runners.
"""
import abc
import contextlib
import datetime
import enum
import hashlib
import logging
import re
import signal
import time
from types import FrameType
from typing import Any, Callable, List, Optional, Tuple, Union

from absl import flags
from absl.testing import absltest
from google.protobuf import json_format
import grpc

from framework import xds_flags
from framework import xds_k8s_flags
from framework import xds_url_map_testcase
from framework.helpers import grpc as helpers_grpc
from framework.helpers import rand as helpers_rand
from framework.helpers import retryers
from framework.helpers import skips
from framework.infrastructure import gcp
from framework.infrastructure import k8s
from framework.infrastructure import traffic_director
from framework.rpc import grpc_channelz
from framework.rpc import grpc_csds
from framework.rpc import grpc_testing
from framework.test_app import client_app
from framework.test_app import server_app
from framework.test_app.runners.k8s import k8s_xds_client_runner
from framework.test_app.runners.k8s import k8s_xds_server_runner

logger = logging.getLogger(__name__)
# TODO(yashkt): We will no longer need this flag once Core exposes local certs
# from channelz
_CHECK_LOCAL_CERTS = flags.DEFINE_bool(
    "check_local_certs",
    default=True,
    help="Security Tests also check the value of local certs")
flags.adopt_module_key_flags(xds_flags)
flags.adopt_module_key_flags(xds_k8s_flags)

# Type aliases
TrafficDirectorManager = traffic_director.TrafficDirectorManager
TrafficDirectorAppNetManager = traffic_director.TrafficDirectorAppNetManager
TrafficDirectorSecureManager = traffic_director.TrafficDirectorSecureManager
XdsTestServer = server_app.XdsTestServer
XdsTestClient = client_app.XdsTestClient
KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner
KubernetesClientRunner = k8s_xds_client_runner.KubernetesClientRunner
LoadBalancerStatsResponse = grpc_testing.LoadBalancerStatsResponse
_ChannelState = grpc_channelz.ChannelState
_timedelta = datetime.timedelta
ClientConfig = grpc_csds.ClientConfig
# pylint complains about signal.Signals for some reason.
_SignalNum = Union[int, signal.Signals]  # pylint: disable=no-member
_SignalHandler = Callable[[_SignalNum, Optional[FrameType]], Any]

# Timeout, in seconds, used by assertRpcsEventuallyGoToGivenServers while
# waiting for RPC distribution to match the expected servers.
_TD_CONFIG_MAX_WAIT_SEC = 600


class TdPropagationRetryableError(Exception):
    """Indicates that TD config hasn't propagated yet, and it's safe to retry"""


class XdsKubernetesBaseTestCase(absltest.TestCase):
    """Common base for xDS Kubernetes test cases.

    setUpClass() copies flag values into class attributes and creates the
    Kubernetes and GCP API managers; tearDownClass() closes them. The
    instance methods are shared helpers and assertions over the test client's
    load-balancer stats and CSDS config dump.
    """
    lang_spec: skips.TestConfig
    client_namespace: str
    client_runner: KubernetesClientRunner
    ensure_firewall: bool
    force_cleanup: bool
    gcp_api_manager: gcp.api.GcpApiManager
    gcp_service_account: Optional[str]
    k8s_api_manager: k8s.KubernetesApiManager
    secondary_k8s_api_manager: k8s.KubernetesApiManager
    network: str
    project: str
    resource_prefix: str
    resource_suffix: str = ''
    # Whether to randomize resources names for each test by appending a
    # unique suffix.
    resource_suffix_randomize: bool = True
    server_maintenance_port: Optional[int]
    server_namespace: str
    server_runner: KubernetesServerRunner
    server_xds_host: str
    server_xds_port: int
    td: TrafficDirectorManager
    td_bootstrap_image: str
    # SIGINT handler that was active before setUp() installed handle_sigint.
    _prev_sigint_handler: Optional[_SignalHandler] = None
    # True while handle_sigint is running its cleanup.
    _handling_sigint: bool = False

    @staticmethod
    def is_supported(config: skips.TestConfig) -> bool:
        """Overridden by the test class to decide if the config is supported.

        Returns:
          A bool indicating if the given config is supported.
        """
        del config
        return True

    @classmethod
    def setUpClass(cls):
        """Hook method for setting up class fixture before running tests in
        the class.
        """
        logger.info('----- Testing %s -----', cls.__name__)
        logger.info('Logs timezone: %s', time.localtime().tm_zone)

        # Raises unittest.SkipTest if given client/server/version does not
        # support current test case.
        cls.lang_spec = skips.evaluate_test_config(cls.is_supported)

        # Must be called before KubernetesApiManager or GcpApiManager init.
        xds_flags.set_socket_default_timeout_from_flag()

        # GCP
        cls.project = xds_flags.PROJECT.value
        cls.network = xds_flags.NETWORK.value
        cls.gcp_service_account = xds_k8s_flags.GCP_SERVICE_ACCOUNT.value
        cls.td_bootstrap_image = xds_k8s_flags.TD_BOOTSTRAP_IMAGE.value
        cls.xds_server_uri = xds_flags.XDS_SERVER_URI.value
        cls.ensure_firewall = xds_flags.ENSURE_FIREWALL.value
        cls.firewall_allowed_ports = xds_flags.FIREWALL_ALLOWED_PORTS.value
        cls.compute_api_version = xds_flags.COMPUTE_API_VERSION.value

        # Resource names.
        cls.resource_prefix = xds_flags.RESOURCE_PREFIX.value
        if xds_flags.RESOURCE_SUFFIX.value is not None:
            # An explicit suffix disables per-test suffix randomization.
            cls.resource_suffix_randomize = False
            cls.resource_suffix = xds_flags.RESOURCE_SUFFIX.value

        # Test server
        cls.server_image = xds_k8s_flags.SERVER_IMAGE.value
        cls.server_name = xds_flags.SERVER_NAME.value
        cls.server_port = xds_flags.SERVER_PORT.value
        cls.server_maintenance_port = xds_flags.SERVER_MAINTENANCE_PORT.value
        cls.server_xds_host = xds_flags.SERVER_NAME.value
        cls.server_xds_port = xds_flags.SERVER_XDS_PORT.value

        # Test client
        cls.client_image = xds_k8s_flags.CLIENT_IMAGE.value
        cls.client_name = xds_flags.CLIENT_NAME.value
        cls.client_port = xds_flags.CLIENT_PORT.value

        # Test suite settings
        cls.force_cleanup = xds_flags.FORCE_CLEANUP.value
        cls.debug_use_port_forwarding = \
            xds_k8s_flags.DEBUG_USE_PORT_FORWARDING.value
        cls.enable_workload_identity = \
            xds_k8s_flags.ENABLE_WORKLOAD_IDENTITY.value
        cls.check_local_certs = _CHECK_LOCAL_CERTS.value

        # Resource managers
        cls.k8s_api_manager = k8s.KubernetesApiManager(
            xds_k8s_flags.KUBE_CONTEXT.value)
        cls.secondary_k8s_api_manager = k8s.KubernetesApiManager(
            xds_k8s_flags.SECONDARY_KUBE_CONTEXT.value)
        cls.gcp_api_manager = gcp.api.GcpApiManager()

    @classmethod
    def tearDownClass(cls):
        """Closes the API managers created in setUpClass()."""
        cls.k8s_api_manager.close()
        cls.secondary_k8s_api_manager.close()
        cls.gcp_api_manager.close()

    def setUp(self):
        """Installs handle_sigint as the SIGINT handler, saving the previous
        handler so it can be restored."""
        self._prev_sigint_handler = signal.signal(signal.SIGINT,
                                                  self.handle_sigint)

    def handle_sigint(self, signalnum: _SignalNum,
                      frame: Optional[FrameType]) -> None:
        """SIGINT handler: runs teardown, restores the previous handler and
        raises KeyboardInterrupt.

        Args:
          signalnum: The received signal number (unused).
          frame: The interrupted stack frame (unused).

        Raises:
          KeyboardInterrupt: Always, after cleanup has run.
        """
        logger.info('Caught Ctrl+C, cleaning up...')
        self._handling_sigint = True
        # Force resource cleanup by their name. Addresses the case where ctrl-c
        # is pressed while waiting for the resource creation.
        self.force_cleanup = True
        self.tearDown()
        self.tearDownClass()
        self._handling_sigint = False
        if self._prev_sigint_handler is not None:
            signal.signal(signal.SIGINT, self._prev_sigint_handler)
        raise KeyboardInterrupt

    @contextlib.contextmanager
    def subTest(self, msg, **params):  # noqa pylint: disable=signature-differs
        """Logs the start and finish of a subtest.

        NOTE(review): the context manager returned by super().subTest() is
        yielded to the caller rather than entered here, so unittest's
        per-subtest failure handling is not applied by this wrapper. Confirm
        this is intended before relying on subtest semantics.
        """
        logger.info('--- Starting subTest %s.%s ---', self.id(), msg)
        try:
            yield super().subTest(msg, **params)
        finally:
            # The finish line is suppressed while handle_sigint is tearing
            # down, since the subtest did not actually finish.
            if not self._handling_sigint:
                logger.info('--- Finished subTest %s.%s ---', self.id(), msg)

    def setupTrafficDirectorGrpc(self):
        """Configures TD for gRPC on the server xDS host/port, using the
        server maintenance port for health checks."""
        self.td.setup_for_grpc(self.server_xds_host,
                               self.server_xds_port,
                               health_check_port=self.server_maintenance_port)

    def setupServerBackends(self,
                            *,
                            wait_for_healthy_status=True,
                            server_runner=None,
                            max_rate_per_endpoint: Optional[int] = None):
        """Adds the server service's NEG backends to the TD backend service.

        Args:
          wait_for_healthy_status: When True, block until the backends
            report healthy.
          server_runner: Runner owning the service; defaults to
            self.server_runner.
          max_rate_per_endpoint: Optional rate limit passed through to TD.
        """
        if server_runner is None:
            server_runner = self.server_runner
        # Load Backends
        neg_name, neg_zones = server_runner.k8s_namespace.get_service_neg(
            server_runner.service_name, self.server_port)

        # Add backends to the Backend Service
        self.td.backend_service_add_neg_backends(
            neg_name, neg_zones, max_rate_per_endpoint=max_rate_per_endpoint)
        if wait_for_healthy_status:
            self.td.wait_for_backends_healthy_status()

    def removeServerBackends(self, *, server_runner=None):
        """Removes the server service's NEG backends from the TD backend
        service. server_runner defaults to self.server_runner."""
        if server_runner is None:
            server_runner = self.server_runner
        # Load Backends
        neg_name, neg_zones = server_runner.k8s_namespace.get_service_neg(
            server_runner.service_name, self.server_port)

        # Remove backends from the Backend Service
        self.td.backend_service_remove_neg_backends(neg_name, neg_zones)

    def assertSuccessfulRpcs(self,
                             test_client: XdsTestClient,
                             num_rpcs: int = 100):
        """Asserts every reported backend got RPCs and none of num_rpcs
        failed."""
        lb_stats = self.getClientRpcStats(test_client, num_rpcs)
        self.assertAllBackendsReceivedRpcs(lb_stats)
        failed = int(lb_stats.num_failures)
        self.assertLessEqual(
            failed,
            0,
            msg=f'Expected all RPCs to succeed: {failed} of {num_rpcs} failed')

    @staticmethod
    def diffAccumulatedStatsPerMethod(
        before: grpc_testing.LoadBalancerAccumulatedStatsResponse,
        after: grpc_testing.LoadBalancerAccumulatedStatsResponse
    ) -> grpc_testing.LoadBalancerAccumulatedStatsResponse:
        """Only diffs stats_per_method, as the other fields are deprecated.

        Returns:
          A response holding, per method and status, the positive count
          difference `after - before`. Zero differences are omitted.

        Raises:
          AssertionError: If any count decreased between the snapshots.
        """
        diff = grpc_testing.LoadBalancerAccumulatedStatsResponse()
        for method, method_stats in after.stats_per_method.items():
            for status, count in method_stats.result.items():
                count -= before.stats_per_method[method].result[status]
                if count < 0:
                    raise AssertionError("Diff of count shouldn't be negative")
                if count > 0:
                    diff.stats_per_method[method].result[status] = count
        return diff

    def assertRpcStatusCodes(self,
                             test_client: XdsTestClient,
                             *,
                             expected_status: grpc.StatusCode,
                             duration: _timedelta,
                             method: str,
                             stray_rpc_limit: int = 0) -> None:
        """Assert all RPCs for a method are completing with a certain status.

        Samples the client's accumulated stats before and after sleeping for
        `duration`, then checks the difference for `method`.

        Args:
          test_client: Client whose accumulated stats are sampled.
          expected_status: The status RPCs are expected to complete with.
          duration: How long to wait between the two samples.
          method: RPC method name key in stats_per_method.
          stray_rpc_limit: Maximum count tolerated for any single
            unexpected status.
        """
        # Sending with pre-set QPS for a period of time
        before_stats = test_client.get_load_balancer_accumulated_stats()
        response_type = 'LoadBalancerAccumulatedStatsResponse'
        logger.info('Received %s from test client %s: before:\n%s',
                    response_type, test_client.hostname, before_stats)
        time.sleep(duration.total_seconds())
        after_stats = test_client.get_load_balancer_accumulated_stats()
        logger.info('Received %s from test client %s: after:\n%s',
                    response_type, test_client.hostname, after_stats)

        diff_stats = self.diffAccumulatedStatsPerMethod(before_stats,
                                                        after_stats)
        stats = diff_stats.stats_per_method[method]
        for found_status_int, count in stats.result.items():
            found_status = helpers_grpc.status_from_int(found_status_int)
            if found_status != expected_status and count > stray_rpc_limit:
                self.fail(f"Expected only status"
                          f" {helpers_grpc.status_pretty(expected_status)},"
                          " but found status"
                          f" {helpers_grpc.status_pretty(found_status)}"
                          f" for method {method}:\n{diff_stats}")

        # grpc.StatusCode values are (int_code, str_name) tuples.
        expected_status_int: int = expected_status.value[0]
        self.assertGreater(
            stats.result[expected_status_int],
            0,
            msg=("Expected non-zero RPCs with status"
                 f" {helpers_grpc.status_pretty(expected_status)}"
                 f" for method {method}, got:\n{diff_stats}"))

    def assertRpcsEventuallyGoToGivenServers(self,
                                             test_client: XdsTestClient,
                                             servers: List[XdsTestServer],
                                             num_rpcs: int = 100):
        """Retries, every second for up to _TD_CONFIG_MAX_WAIT_SEC, until RPCs
        succeed and are spread across exactly the given servers.

        Raises:
          retryers.RetryError: If the condition is not met before timeout.
        """
        retryer = retryers.constant_retryer(
            wait_fixed=datetime.timedelta(seconds=1),
            timeout=datetime.timedelta(seconds=_TD_CONFIG_MAX_WAIT_SEC),
            log_level=logging.INFO)
        try:
            retryer(self._assertRpcsEventuallyGoToGivenServers, test_client,
                    servers, num_rpcs)
        except retryers.RetryError as retry_error:
            logger.exception(
                'Rpcs did not go to expected servers before timeout %s',
                _TD_CONFIG_MAX_WAIT_SEC)
            raise retry_error

    def _assertRpcsEventuallyGoToGivenServers(self, test_client: XdsTestClient,
                                              servers: List[XdsTestServer],
                                              num_rpcs: int):
        """Single attempt: all RPCs succeed, every given server received RPCs,
        and no other peer did."""
        server_hostnames = [server.hostname for server in servers]
        logger.info('Verifying RPCs go to servers %s', server_hostnames)
        lb_stats = self.getClientRpcStats(test_client, num_rpcs)
        failed = int(lb_stats.num_failures)
        self.assertLessEqual(
            failed,
            0,
            msg=f'Expected all RPCs to succeed: {failed} of {num_rpcs} failed')
        for server_hostname in server_hostnames:
            self.assertIn(server_hostname, lb_stats.rpcs_by_peer,
                          f'Server {server_hostname} did not receive RPCs')
        for server_hostname in lb_stats.rpcs_by_peer.keys():
            self.assertIn(server_hostname, server_hostnames,
                          f'Unexpected server {server_hostname} received RPCs')

    def assertXdsConfigExists(self, test_client: XdsTestClient):
        """Asserts the client's CSDS dump contains listener, route, cluster
        and endpoint configs.

        Both the per-xDS `xds_config` entries and `generic_xds_configs`
        (matched by type_url suffix) are considered.
        """
        config = test_client.csds.fetch_client_status(log_level=logging.INFO)
        self.assertIsNotNone(config)
        seen = set()
        want = frozenset([
            'listener_config',
            'cluster_config',
            'route_config',
            'endpoint_config',
        ])
        for xds_config in config.xds_config:
            seen.add(xds_config.WhichOneof('per_xds_config'))
        for generic_xds_config in config.generic_xds_configs:
            if re.search(r'\.Listener$', generic_xds_config.type_url):
                seen.add('listener_config')
            elif re.search(r'\.RouteConfiguration$',
                           generic_xds_config.type_url):
                seen.add('route_config')
            elif re.search(r'\.Cluster$', generic_xds_config.type_url):
                seen.add('cluster_config')
            elif re.search(r'\.ClusterLoadAssignment$',
                           generic_xds_config.type_url):
                seen.add('endpoint_config')
        logger.debug('Received xDS config dump: %s',
                     json_format.MessageToJson(config, indent=2))
        self.assertSameElements(want, seen)

    def assertRouteConfigUpdateTrafficHandoff(
            self, test_client: XdsTestClient,
            previous_route_config_version: str, retry_wait_second: int,
            timeout_second: int):
        """Waits until the client's RDS version differs from
        `previous_route_config_version`, asserting RPCs succeed throughout.

        Raises:
          retryers.RetryError: If the route config version did not change
            before `timeout_second`.
        """
        retryer = retryers.constant_retryer(
            wait_fixed=datetime.timedelta(seconds=retry_wait_second),
            timeout=datetime.timedelta(seconds=timeout_second),
            retry_on_exceptions=(TdPropagationRetryableError,),
            logger=logger,
            log_level=logging.INFO)
        # Initialized up front so the except clause below can always log it.
        dumped_config = None
        try:
            for attempt in retryer:
                with attempt:
                    self.assertSuccessfulRpcs(test_client)
                    raw_config = test_client.csds.fetch_client_status(
                        log_level=logging.INFO)
                    dumped_config = xds_url_map_testcase.DumpedXdsConfig(
                        json_format.MessageToDict(raw_config))
                    route_config_version = dumped_config.rds_version
                    if previous_route_config_version == route_config_version:
                        logger.info(
                            'Routing config not propagated yet. Retrying.')
                        raise TdPropagationRetryableError(
                            "CSDS not get updated routing config corresponding"
                            " to the second set of url maps")
                    else:
                        self.assertSuccessfulRpcs(test_client)
                        logger.info(
                            ('[SUCCESS] Confirmed successful RPC with the '
                             'updated routing config, version=%s'),
                            route_config_version)
        except retryers.RetryError as retry_error:
            logger.info(
                ('Retry exhausted. TD routing config propagation failed after '
                 'timeout %ds. Last seen client config dump: %s'),
                timeout_second, dumped_config)
            raise retry_error

    def assertFailedRpcs(self,
                         test_client: XdsTestClient,
                         num_rpcs: Optional[int] = 100):
        """Asserts that all `num_rpcs` RPCs sent by the client failed."""
        lb_stats = self.getClientRpcStats(test_client, num_rpcs)
        failed = int(lb_stats.num_failures)
        self.assertEqual(
            failed,
            num_rpcs,
            msg=f'Expected all RPCs to fail: {failed} of {num_rpcs} failed')

    @staticmethod
    def getClientRpcStats(test_client: XdsTestClient,
                          num_rpcs: int) -> LoadBalancerStatsResponse:
        """Fetches and logs the client's stats for the next `num_rpcs` RPCs."""
        lb_stats = test_client.get_load_balancer_stats(num_rpcs=num_rpcs)
        logger.info(
            'Received LoadBalancerStatsResponse from test client %s:\n%s',
            test_client.hostname, lb_stats)
        return lb_stats

    def assertAllBackendsReceivedRpcs(self, lb_stats):
        """Asserts every peer listed in lb_stats.rpcs_by_peer got >0 RPCs."""
        # TODO(sergiitk): assert backends length
        for backend, rpcs_count in lb_stats.rpcs_by_peer.items():
            self.assertGreater(
                int(rpcs_count),
                0,
                msg=f'Backend {backend} did not receive a single RPC')


class IsolatedXdsKubernetesTestCase(XdsKubernetesBaseTestCase,
                                    metaclass=abc.ABCMeta):
    """Isolated test case.

    Base class for tests cases where infra resources are created before
    each test, and destroyed after.
    """

    def setUp(self):
        """Hook method for setting up the test fixture before exercising it."""
        super().setUp()

        if self.resource_suffix_randomize:
            self.resource_suffix = helpers_rand.random_resource_suffix()
        logger.info('Test run resource prefix: %s, suffix: %s',
                    self.resource_prefix, self.resource_suffix)

        # TD Manager
        self.td = self.initTrafficDirectorManager()

        # Test Server runner
        self.server_namespace = KubernetesServerRunner.make_namespace_name(
            self.resource_prefix, self.resource_suffix)
        self.server_runner = self.initKubernetesServerRunner()

        # Test Client runner
        self.client_namespace = KubernetesClientRunner.make_namespace_name(
            self.resource_prefix, self.resource_suffix)
        self.client_runner = self.initKubernetesClientRunner()

        # Ensures the firewall exists
        if self.ensure_firewall:
            self.td.create_firewall_rule(
                allowed_ports=self.firewall_allowed_ports)

        # Randomize xds port, when it's set to 0
        if self.server_xds_port == 0:
            # TODO(sergiitk): this is prone to race conditions:
            #  The port might not be taken now, but there's no guarantee
            #  it won't be taken until the tests get to creating
            #  forwarding rule. This check is better than nothing,
            #  but we should find a better approach.
            self.server_xds_port = self.td.find_unused_forwarding_rule_port()
            logger.info('Found unused xds port: %s', self.server_xds_port)

    @abc.abstractmethod
    def initTrafficDirectorManager(self) -> TrafficDirectorManager:
        raise NotImplementedError

    @abc.abstractmethod
    def initKubernetesServerRunner(self) -> KubernetesServerRunner:
        raise NotImplementedError

    @abc.abstractmethod
    def initKubernetesClientRunner(self) -> KubernetesClientRunner:
        raise NotImplementedError

    def tearDown(self):
        """Cleans up test resources and fails if any pod restarted.

        Pod restart counts are read before cleanup. Cleanup is retried up to
        3 times, 10 seconds apart; a final cleanup failure is only logged.
        """
        logger.info('----- TestMethod %s teardown -----', self.id())
        logger.debug('Getting pods restart times')
        client_restarts: int = 0
        server_restarts: int = 0
        try:
            client_restarts = self.client_runner.get_pod_restarts(
                self.client_runner.deployment)
            server_restarts = self.server_runner.get_pod_restarts(
                self.server_runner.deployment)
        except (retryers.RetryError, k8s.NotFound) as e:
            logger.exception(e)

        retryer = retryers.constant_retryer(wait_fixed=_timedelta(seconds=10),
                                            attempts=3,
                                            log_level=logging.INFO)
        try:
            retryer(self.cleanup)
        except retryers.RetryError:
            logger.exception('Got error during teardown')
        finally:
            logger.info('----- Test client/server logs -----')
            self.client_runner.logs_explorer_run_history_links()
            self.server_runner.logs_explorer_run_history_links()

            # Fail if any of the pods restarted.
            self.assertEqual(
                client_restarts,
                0,
                msg=('Client pods unexpectedly restarted'
                     f' {client_restarts} times during test.'
                     ' In most cases, this is caused by the test client app'
                     ' crash.'))
            self.assertEqual(
                server_restarts,
                0,
                msg=('Server pods unexpectedly restarted'
                     f' {server_restarts} times during test.'
                     ' In most cases, this is caused by the test server app'
                     ' crash.'))

    def cleanup(self):
        """Cleans up TD, client and server resources (forced if
        self.force_cleanup is set)."""
        self.td.cleanup(force=self.force_cleanup)
        self.client_runner.cleanup(force=self.force_cleanup)
        self.server_runner.cleanup(force=self.force_cleanup,
                                   force_namespace=self.force_cleanup)


class RegularXdsKubernetesTestCase(IsolatedXdsKubernetesTestCase):
    """Regular test case base class for testing PSM features in isolation."""

    @classmethod
    def setUpClass(cls):
        """Hook method for setting up class fixture before running tests in
        the class.
        """
        super().setUpClass()
        if cls.server_maintenance_port is None:
            cls.server_maintenance_port = \
                KubernetesServerRunner.DEFAULT_MAINTENANCE_PORT

    def initTrafficDirectorManager(self) -> TrafficDirectorManager:
        return TrafficDirectorManager(
            self.gcp_api_manager,
            project=self.project,
            resource_prefix=self.resource_prefix,
            resource_suffix=self.resource_suffix,
            network=self.network,
            compute_api_version=self.compute_api_version)

    def initKubernetesServerRunner(self) -> KubernetesServerRunner:
        return KubernetesServerRunner(
            k8s.KubernetesNamespace(self.k8s_api_manager,
                                    self.server_namespace),
            deployment_name=self.server_name,
            image_name=self.server_image,
            td_bootstrap_image=self.td_bootstrap_image,
            gcp_project=self.project,
            gcp_api_manager=self.gcp_api_manager,
            gcp_service_account=self.gcp_service_account,
            xds_server_uri=self.xds_server_uri,
            network=self.network,
            debug_use_port_forwarding=self.debug_use_port_forwarding,
            enable_workload_identity=self.enable_workload_identity)

    def initKubernetesClientRunner(self) -> KubernetesClientRunner:
        return KubernetesClientRunner(
            k8s.KubernetesNamespace(self.k8s_api_manager,
                                    self.client_namespace),
            deployment_name=self.client_name,
            image_name=self.client_image,
            td_bootstrap_image=self.td_bootstrap_image,
            gcp_project=self.project,
            gcp_api_manager=self.gcp_api_manager,
            gcp_service_account=self.gcp_service_account,
            xds_server_uri=self.xds_server_uri,
            network=self.network,
            debug_use_port_forwarding=self.debug_use_port_forwarding,
            enable_workload_identity=self.enable_workload_identity,
            stats_port=self.client_port,
            reuse_namespace=self.server_namespace == self.client_namespace)

    def startTestServers(self,
                         replica_count=1,
                         server_runner=None,
                         **kwargs) -> List[XdsTestServer]:
        """Runs `replica_count` servers and sets their xDS address.

        server_runner defaults to self.server_runner; extra kwargs are passed
        to its run().
        """
        if server_runner is None:
            server_runner = self.server_runner
        test_servers = server_runner.run(
            replica_count=replica_count,
            test_port=self.server_port,
            maintenance_port=self.server_maintenance_port,
            **kwargs)
        for test_server in test_servers:
            test_server.set_xds_address(self.server_xds_host,
                                        self.server_xds_port)
        return test_servers

    def startTestClient(self, test_server: XdsTestServer,
                        **kwargs) -> XdsTestClient:
        """Runs the client targeting test_server's xDS URI and waits for an
        active server channel."""
        test_client = self.client_runner.run(server_target=test_server.xds_uri,
                                             **kwargs)
        test_client.wait_for_active_server_channel()
        return test_client


class AppNetXdsKubernetesTestCase(RegularXdsKubernetesTestCase):
    """Regular test case that uses TrafficDirectorAppNetManager."""
    td: TrafficDirectorAppNetManager

    def initTrafficDirectorManager(self) -> TrafficDirectorAppNetManager:
        return TrafficDirectorAppNetManager(
            self.gcp_api_manager,
            project=self.project,
            resource_prefix=self.resource_prefix,
            resource_suffix=self.resource_suffix,
            network=self.network,
            compute_api_version=self.compute_api_version)


class SecurityXdsKubernetesTestCase(IsolatedXdsKubernetesTestCase):
    """Test case base class for testing PSM security features in isolation."""
    td: TrafficDirectorSecureManager

    class SecurityMode(enum.Enum):
        MTLS = enum.auto()
        TLS = enum.auto()
        PLAINTEXT = enum.auto()

    @classmethod
    def setUpClass(cls):
        """Hook method for setting up class fixture before running tests in
        the class.
        """
        super().setUpClass()
        if cls.server_maintenance_port is None:
            # In secure mode, the maintenance port is different from
            # the test port to keep it insecure, and make
            # Health Checks and Channelz tests available.
            # When not provided, use explicit numeric port value, so
            # Backend Health Checks are created on a fixed port.
            cls.server_maintenance_port = \
                KubernetesServerRunner.DEFAULT_SECURE_MODE_MAINTENANCE_PORT

    def initTrafficDirectorManager(self) -> TrafficDirectorSecureManager:
        return TrafficDirectorSecureManager(
            self.gcp_api_manager,
            project=self.project,
            resource_prefix=self.resource_prefix,
            resource_suffix=self.resource_suffix,
            network=self.network,
            compute_api_version=self.compute_api_version)

    def initKubernetesServerRunner(self) -> KubernetesServerRunner:
        return KubernetesServerRunner(
            k8s.KubernetesNamespace(self.k8s_api_manager,
                                    self.server_namespace),
            deployment_name=self.server_name,
            image_name=self.server_image,
            td_bootstrap_image=self.td_bootstrap_image,
            gcp_project=self.project,
            gcp_api_manager=self.gcp_api_manager,
            gcp_service_account=self.gcp_service_account,
            network=self.network,
            xds_server_uri=self.xds_server_uri,
            deployment_template='server-secure.deployment.yaml',
            debug_use_port_forwarding=self.debug_use_port_forwarding)

    def initKubernetesClientRunner(self) -> KubernetesClientRunner:
        return KubernetesClientRunner(
            k8s.KubernetesNamespace(self.k8s_api_manager,
                                    self.client_namespace),
            deployment_name=self.client_name,
            image_name=self.client_image,
            td_bootstrap_image=self.td_bootstrap_image,
            gcp_project=self.project,
            gcp_api_manager=self.gcp_api_manager,
            gcp_service_account=self.gcp_service_account,
            xds_server_uri=self.xds_server_uri,
            network=self.network,
            deployment_template='client-secure.deployment.yaml',
            stats_port=self.client_port,
            reuse_namespace=self.server_namespace == self.client_namespace,
            debug_use_port_forwarding=self.debug_use_port_forwarding)

    def startSecureTestServer(self, replica_count=1, **kwargs) -> XdsTestServer:
        """Runs the server in secure mode; returns the first replica with its
        xDS address set."""
        test_server = self.server_runner.run(
            replica_count=replica_count,
            test_port=self.server_port,
            maintenance_port=self.server_maintenance_port,
            secure_mode=True,
            **kwargs)[0]
        test_server.set_xds_address(self.server_xds_host, self.server_xds_port)
        return test_server

    def setupSecurityPolicies(self, *, server_tls, server_mtls, client_tls,
                              client_mtls):
        """Configures client- and server-side security in TD with the given
        TLS/mTLS switches."""
        self.td.setup_client_security(server_namespace=self.server_namespace,
                                      server_name=self.server_name,
                                      tls=client_tls,
                                      mtls=client_mtls)
        self.td.setup_server_security(server_namespace=self.server_namespace,
                                      server_name=self.server_name,
                                      server_port=self.server_port,
                                      tls=server_tls,
                                      mtls=server_mtls)

    def startSecureTestClient(self,
                              test_server: XdsTestServer,
                              *,
                              wait_for_active_server_channel=True,
                              **kwargs) -> XdsTestClient:
        """Runs the client in secure mode targeting test_server's xDS URI,
        optionally waiting for an active server channel."""
        test_client = self.client_runner.run(server_target=test_server.xds_uri,
                                             secure_mode=True,
                                             **kwargs)
        if wait_for_active_server_channel:
            test_client.wait_for_active_server_channel()
        return test_client

    def assertTestAppSecurity(self, mode: SecurityMode,
                              test_client: XdsTestClient,
                              test_server: XdsTestServer):
        """Asserts the connected client/server sockets match `mode`.

        Raises:
          TypeError: If `mode` is not a known SecurityMode.
        """
        client_socket, server_socket = self.getConnectedSockets(
            test_client, test_server)
        server_security: grpc_channelz.Security = server_socket.security
        client_security: grpc_channelz.Security = client_socket.security
        logger.info('Server certs: %s', self.debug_sock_certs(server_security))
        logger.info('Client certs: %s', self.debug_sock_certs(client_security))

        if mode is self.SecurityMode.MTLS:
            self.assertSecurityMtls(client_security, server_security)
        elif mode is self.SecurityMode.TLS:
            self.assertSecurityTls(client_security, server_security)
        elif mode is self.SecurityMode.PLAINTEXT:
            self.assertSecurityPlaintext(client_security, server_security)
        else:
            raise TypeError('Incorrect security mode')

    def assertSecurityMtls(self, client_security: grpc_channelz.Security,
                           server_security: grpc_channelz.Security):
        """Asserts both sides use TLS and certificates are exchanged in both
        directions. Local-cert comparisons are gated by check_local_certs."""
        self.assertEqual(client_security.WhichOneof('model'),
                         'tls',
                         msg='(mTLS) Client socket security model must be TLS')
        self.assertEqual(server_security.WhichOneof('model'),
                         'tls',
                         msg='(mTLS) Server socket security model must be TLS')
        server_tls, client_tls = server_security.tls, client_security.tls

        # Confirm regular TLS: server local cert == client remote cert
        self.assertNotEmpty(client_tls.remote_certificate,
                            msg="(mTLS) Client remote certificate is missing")
        if self.check_local_certs:
            self.assertNotEmpty(
                server_tls.local_certificate,
                msg="(mTLS) Server local certificate is missing")
            self.assertEqual(
                server_tls.local_certificate,
                client_tls.remote_certificate,
                msg="(mTLS) Server local certificate must match client's "
                "remote certificate")

        # mTLS: server remote cert == client local cert
        self.assertNotEmpty(server_tls.remote_certificate,
                            msg="(mTLS) Server remote certificate is missing")
        if self.check_local_certs:
            self.assertNotEmpty(
                client_tls.local_certificate,
                msg="(mTLS) Client local certificate is missing")
            self.assertEqual(
                server_tls.remote_certificate,
                client_tls.local_certificate,
                msg="(mTLS) Server remote certificate must match client's "
                "local certificate")

    def assertSecurityTls(self, client_security: grpc_channelz.Security,
                          server_security: grpc_channelz.Security):
        """Asserts both sides use TLS, the server presents a certificate and
        no client certificate is exchanged."""
        self.assertEqual(client_security.WhichOneof('model'),
                         'tls',
                         msg='(TLS) Client socket security model must be TLS')
        self.assertEqual(server_security.WhichOneof('model'),
                         'tls',
                         msg='(TLS) Server socket security model must be TLS')
        server_tls, client_tls = server_security.tls, client_security.tls

        # Regular TLS: server local cert == client remote cert
        self.assertNotEmpty(client_tls.remote_certificate,
                            msg="(TLS) Client remote certificate is missing")
        if self.check_local_certs:
            self.assertNotEmpty(server_tls.local_certificate,
                                msg="(TLS) Server local certificate is missing")
            self.assertEqual(
                server_tls.local_certificate,
                client_tls.remote_certificate,
                msg="(TLS) Server local certificate must match client "
                "remote certificate")

        # mTLS must not be used
        self.assertEmpty(
            server_tls.remote_certificate,
            msg="(TLS) Server remote certificate must be empty in TLS mode. "
            "Is server security incorrectly configured for mTLS?")
        self.assertEmpty(
            client_tls.local_certificate,
            msg="(TLS) Client local certificate must be empty in TLS mode. "
            "Is client security incorrectly configured for mTLS?")

    def assertSecurityPlaintext(self, client_security, server_security):
        """Asserts no certificate is present on either side of the
        connection, in either direction."""
        server_tls, client_tls = server_security.tls, client_security.tls
        # Not TLS: the server presents no certificate, so the client receives
        # none.
        self.assertEmpty(
            server_tls.local_certificate,
            msg="(Plaintext) Server local certificate must be empty.")
        self.assertEmpty(
            client_tls.remote_certificate,
            msg="(Plaintext) Client remote certificate must be empty.")

        # Not mTLS: the client presents no certificate, so the server receives
        # none.
        self.assertEmpty(
            server_tls.remote_certificate,
            msg="(Plaintext) Server remote certificate must be empty.")
        self.assertEmpty(
            client_tls.local_certificate,
            msg="(Plaintext) Client local certificate must be empty.")

    def assertClientCannotReachServerRepeatedly(
            self,
            test_client: XdsTestClient,
            *,
            times: Optional[int] = None,
            delay: Optional[_timedelta] = None):
        """
        Asserts that the client repeatedly cannot reach the server.

        With negative tests we can't be absolutely certain expected failure
        state is not caused by something else.
        To mitigate for this, we repeat the checks several times, and expect
        all of them to succeed.

        This is useful in case the channel eventually stabilizes, and RPCs pass.

        Args:
            test_client: An instance of XdsTestClient
            times: Optional; A positive number of times to confirm that
                the server is unreachable. Defaults to `3` attempts.
            delay: Optional; Specifies how long to wait before the next check.
                Defaults to `10` seconds.
        """
        if times is None or times < 1:
            times = 3
        if delay is None:
            delay = _timedelta(seconds=10)

        for i in range(1, times + 1):
            self.assertClientCannotReachServer(test_client)
            if i < times:
                logger.info('Check %s passed, waiting %s before the next check',
                            i, delay)
                time.sleep(delay.total_seconds())

    def assertClientCannotReachServer(self, test_client: XdsTestClient):
        """Asserts the client channel is failing and all RPCs fail."""
        self.assertClientChannelFailed(test_client)
        self.assertFailedRpcs(test_client)

    def assertClientChannelFailed(self, test_client: XdsTestClient):
        """Waits for the server channel to reach TRANSIENT_FAILURE and asserts
        it has exactly one subchannel."""
        channel = test_client.wait_for_server_channel_state(
            state=_ChannelState.TRANSIENT_FAILURE)
        subchannels = list(
            test_client.channelz.list_channel_subchannels(channel))
        self.assertLen(subchannels,
                       1,
                       msg="Client channel must have exactly one subchannel "
                       "in state TRANSIENT_FAILURE.")

    @staticmethod
    def getConnectedSockets(
        test_client: XdsTestClient, test_server: XdsTestServer
    ) -> Tuple[grpc_channelz.Socket, grpc_channelz.Socket]:
        """Returns (client_socket, server_socket) for the active connection."""
        client_sock = test_client.get_active_server_channel_socket()
        server_sock = test_server.get_server_socket_matching_client(client_sock)
        return client_sock, server_sock

    @classmethod
    def debug_sock_certs(cls, security: grpc_channelz.Security):
        """Formats a socket's security info for logging."""
        if security.WhichOneof('model') == 'other':
            return f'other: <{security.other.name}={security.other.value}>'

        return (f'local: <{cls.debug_cert(security.tls.local_certificate)}>, '
                f'remote: <{cls.debug_cert(security.tls.remote_certificate)}>')

    @staticmethod
    def debug_cert(cert):
        """Returns a short fingerprint of `cert` for logs: its SHA-1 digest
        (for identification only, not security) and length."""
        if not cert:
            return 'missing'
        sha1 = hashlib.sha1(cert)
        return f'sha1={sha1.hexdigest()}, len={len(cert)}'