1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import functools
15import logging
16import random
17from typing import Any, Dict, List, Optional, Set
18
19from framework import xds_flags
20from framework.infrastructure import gcp
21
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type aliases
# Compute
_ComputeV1 = gcp.compute.ComputeV1
GcpResource = _ComputeV1.GcpResource
HealthCheckProtocol = _ComputeV1.HealthCheckProtocol
ZonalGcpResource = _ComputeV1.ZonalGcpResource
BackendServiceProtocol = _ComputeV1.BackendServiceProtocol
# Shorthands for the GRPC members; used as default argument values below.
_BackendGRPC = BackendServiceProtocol.GRPC
_HealthCheckGRPC = HealthCheckProtocol.GRPC

# Network Security
_NetworkSecurityV1Beta1 = gcp.network_security.NetworkSecurityV1Beta1
ServerTlsPolicy = gcp.network_security.ServerTlsPolicy
ClientTlsPolicy = gcp.network_security.ClientTlsPolicy
AuthorizationPolicy = gcp.network_security.AuthorizationPolicy

# Network Services
_NetworkServicesV1Alpha1 = gcp.network_services.NetworkServicesV1Alpha1
_NetworkServicesV1Beta1 = gcp.network_services.NetworkServicesV1Beta1
EndpointPolicy = gcp.network_services.EndpointPolicy
GrpcRoute = gcp.network_services.GrpcRoute
Mesh = gcp.network_services.Mesh

# Testing metadata consts
# Passed as the affinity_header when the affinity backend service is created.
TEST_AFFINITY_METADATA_KEY = 'xds_md'
49
50
class TrafficDirectorManager:  # pylint: disable=too-many-public-methods
    """Creates, tracks and deletes the Compute API resources of one setup.

    Each managed resource (health check, backend services, URL maps, target
    proxies, forwarding rules, firewall rule) has create/delete methods and
    an instance attribute that holds the created resource, or None.
    """
    # Compute API wrapper every create/delete call goes through.
    compute: _ComputeV1
    # Leading and trailing parts of every generated resource name.
    resource_prefix: str
    resource_suffix: str

    # Base resource names. make_resource_name() wraps them with the resource
    # prefix and suffix to produce the name actually sent to the API.
    BACKEND_SERVICE_NAME = "backend-service"
    ALTERNATIVE_BACKEND_SERVICE_NAME = "backend-service-alt"
    AFFINITY_BACKEND_SERVICE_NAME = "backend-service-affinity"
    HEALTH_CHECK_NAME = "health-check"
    URL_MAP_NAME = "url-map"
    ALTERNATIVE_URL_MAP_NAME = "url-map-alt"
    URL_MAP_PATH_MATCHER_NAME = "path-matcher"
    TARGET_PROXY_NAME = "target-proxy"
    ALTERNATIVE_TARGET_PROXY_NAME = "target-proxy-alt"
    FORWARDING_RULE_NAME = "forwarding-rule"
    ALTERNATIVE_FORWARDING_RULE_NAME = "forwarding-rule-alt"
    FIREWALL_RULE_NAME = "allow-health-checks"
68
69    def __init__(
70        self,
71        gcp_api_manager: gcp.api.GcpApiManager,
72        project: str,
73        *,
74        resource_prefix: str,
75        resource_suffix: str,
76        network: str = 'default',
77        compute_api_version: str = 'v1',
78    ):
79        # API
80        self.compute = _ComputeV1(gcp_api_manager,
81                                  project,
82                                  version=compute_api_version)
83
84        # Settings
85        self.project: str = project
86        self.network: str = network
87        self.resource_prefix: str = resource_prefix
88        self.resource_suffix: str = resource_suffix
89
90        # Managed resources
91        self.health_check: Optional[GcpResource] = None
92        self.backend_service: Optional[GcpResource] = None
93        # TODO(sergiitk): remove this flag once backend service resource loaded
94        self.backend_service_protocol: Optional[BackendServiceProtocol] = None
95        self.url_map: Optional[GcpResource] = None
96        self.alternative_url_map: Optional[GcpResource] = None
97        self.firewall_rule: Optional[GcpResource] = None
98        self.target_proxy: Optional[GcpResource] = None
99        # TODO(sergiitk): remove this flag once target proxy resource loaded
100        self.target_proxy_is_http: bool = False
101        self.alternative_target_proxy: Optional[GcpResource] = None
102        self.forwarding_rule: Optional[GcpResource] = None
103        self.alternative_forwarding_rule: Optional[GcpResource] = None
104        self.backends: Set[ZonalGcpResource] = set()
105        self.alternative_backend_service: Optional[GcpResource] = None
106        # TODO(sergiitk): remove this flag once backend service resource loaded
107        self.alternative_backend_service_protocol: Optional[
108            BackendServiceProtocol] = None
109        self.alternative_backends: Set[ZonalGcpResource] = set()
110        self.affinity_backend_service: Optional[GcpResource] = None
111        # TODO(sergiitk): remove this flag once backend service resource loaded
112        self.affinity_backend_service_protocol: Optional[
113            BackendServiceProtocol] = None
114        self.affinity_backends: Set[ZonalGcpResource] = set()
115
116    @property
117    def network_url(self):
118        return f'global/networks/{self.network}'
119
    def setup_for_grpc(
            self,
            service_host,
            service_port,
            *,
            backend_protocol: Optional[BackendServiceProtocol] = _BackendGRPC,
            health_check_port: Optional[int] = None):
        """Create the backend resources, then the routing rule map.

        Args:
            service_host: Forwarded to setup_routing_rule_map_for_grpc().
            service_port: Forwarded to setup_routing_rule_map_for_grpc().
            backend_protocol: Forwarded to setup_backend_for_grpc() as
                ``protocol``.
            health_check_port: Forwarded to setup_backend_for_grpc().
        """
        # Backend first: the URL map created next points at the backend
        # service.
        self.setup_backend_for_grpc(protocol=backend_protocol,
                                    health_check_port=health_check_port)
        self.setup_routing_rule_map_for_grpc(service_host, service_port)
130
    def setup_backend_for_grpc(
            self,
            *,
            protocol: Optional[BackendServiceProtocol] = _BackendGRPC,
            health_check_port: Optional[int] = None):
        """Create the health check, then the primary backend service.

        Args:
            protocol: Forwarded to create_backend_service().
            health_check_port: Forwarded to create_health_check() as ``port``.
        """
        # The backend service is created with a reference to the health check.
        self.create_health_check(port=health_check_port)
        self.create_backend_service(protocol)
138
    def setup_routing_rule_map_for_grpc(self, service_host, service_port):
        """Create the URL map, the target proxy and the forwarding rule.

        Args:
            service_host: Host part of the URL map's host rule.
            service_port: Port part of the URL map's host rule; also the port
                of the forwarding rule.
        """
        # Each step uses the resource stored by the previous one.
        self.create_url_map(service_host, service_port)
        self.create_target_proxy()
        self.create_forwarding_rule(service_port)
143
144    def cleanup(self, *, force=False):
145        # Cleanup in the reverse order of creation
146        self.delete_forwarding_rule(force=force)
147        self.delete_alternative_forwarding_rule(force=force)
148        self.delete_target_http_proxy(force=force)
149        self.delete_target_grpc_proxy(force=force)
150        self.delete_alternative_target_grpc_proxy(force=force)
151        self.delete_url_map(force=force)
152        self.delete_alternative_url_map(force=force)
153        self.delete_backend_service(force=force)
154        self.delete_alternative_backend_service(force=force)
155        self.delete_affinity_backend_service(force=force)
156        self.delete_health_check(force=force)
157
158    @functools.lru_cache(None)
159    def make_resource_name(self, name: str) -> str:
160        """Make dash-separated resource name with resource prefix and suffix."""
161        parts = [self.resource_prefix, name]
162        # Avoid trailing dash when the suffix is empty.
163        if self.resource_suffix:
164            parts.append(self.resource_suffix)
165        return '-'.join(parts)
166
167    def create_health_check(
168            self,
169            *,
170            protocol: Optional[HealthCheckProtocol] = _HealthCheckGRPC,
171            port: Optional[int] = None):
172        if self.health_check:
173            raise ValueError(f'Health check {self.health_check.name} '
174                             'already created, delete it first')
175        if protocol is None:
176            protocol = _HealthCheckGRPC
177
178        name = self.make_resource_name(self.HEALTH_CHECK_NAME)
179        logger.info('Creating %s Health Check "%s"', protocol.name, name)
180        resource = self.compute.create_health_check(name, protocol, port=port)
181        self.health_check = resource
182
183    def delete_health_check(self, force=False):
184        if force:
185            name = self.make_resource_name(self.HEALTH_CHECK_NAME)
186        elif self.health_check:
187            name = self.health_check.name
188        else:
189            return
190        logger.info('Deleting Health Check "%s"', name)
191        self.compute.delete_health_check(name)
192        self.health_check = None
193
194    def create_backend_service(
195            self,
196            protocol: Optional[BackendServiceProtocol] = _BackendGRPC,
197            subset_size: Optional[int] = None,
198            affinity_header: Optional[str] = None,
199            locality_lb_policies: Optional[List[dict]] = None,
200            outlier_detection: Optional[dict] = None):
201        if protocol is None:
202            protocol = _BackendGRPC
203
204        name = self.make_resource_name(self.BACKEND_SERVICE_NAME)
205        logger.info('Creating %s Backend Service "%s"', protocol.name, name)
206        resource = self.compute.create_backend_service_traffic_director(
207            name,
208            health_check=self.health_check,
209            protocol=protocol,
210            subset_size=subset_size,
211            affinity_header=affinity_header,
212            locality_lb_policies=locality_lb_policies,
213            outlier_detection=outlier_detection)
214        self.backend_service = resource
215        self.backend_service_protocol = protocol
216
217    def load_backend_service(self):
218        name = self.make_resource_name(self.BACKEND_SERVICE_NAME)
219        resource = self.compute.get_backend_service_traffic_director(name)
220        self.backend_service = resource
221
222    def delete_backend_service(self, force=False):
223        if force:
224            name = self.make_resource_name(self.BACKEND_SERVICE_NAME)
225        elif self.backend_service:
226            name = self.backend_service.name
227        else:
228            return
229        logger.info('Deleting Backend Service "%s"', name)
230        self.compute.delete_backend_service(name)
231        self.backend_service = None
232
233    def backend_service_add_neg_backends(self,
234                                         name,
235                                         zones,
236                                         max_rate_per_endpoint: Optional[
237                                             int] = None):
238        logger.info('Waiting for Network Endpoint Groups to load endpoints.')
239        for zone in zones:
240            backend = self.compute.wait_for_network_endpoint_group(name, zone)
241            logger.info('Loaded NEG "%s" in zone %s', backend.name,
242                        backend.zone)
243            self.backends.add(backend)
244        self.backend_service_patch_backends(max_rate_per_endpoint)
245
246    def backend_service_remove_neg_backends(self, name, zones):
247        logger.info('Waiting for Network Endpoint Groups to load endpoints.')
248        for zone in zones:
249            backend = self.compute.wait_for_network_endpoint_group(name, zone)
250            logger.info('Loaded NEG "%s" in zone %s', backend.name,
251                        backend.zone)
252            self.backends.remove(backend)
253        self.backend_service_patch_backends()
254
255    def backend_service_patch_backends(
256            self, max_rate_per_endpoint: Optional[int] = None):
257        logging.info('Adding backends to Backend Service %s: %r',
258                     self.backend_service.name, self.backends)
259        self.compute.backend_service_patch_backends(self.backend_service,
260                                                    self.backends,
261                                                    max_rate_per_endpoint)
262
263    def backend_service_remove_all_backends(self):
264        logging.info('Removing backends from Backend Service %s',
265                     self.backend_service.name)
266        self.compute.backend_service_remove_all_backends(self.backend_service)
267
    def wait_for_backends_healthy_status(self):
        """Delegate the health wait for the primary backend service.

        Passes the primary backend service and its tracked backends to the
        Compute API wrapper's wait_for_backends_healthy_status().
        """
        logger.debug(
            "Waiting for Backend Service %s to report all backends healthy %r",
            self.backend_service, self.backends)
        self.compute.wait_for_backends_healthy_status(self.backend_service,
                                                      self.backends)
274
275    def create_alternative_backend_service(
276            self, protocol: Optional[BackendServiceProtocol] = _BackendGRPC):
277        if protocol is None:
278            protocol = _BackendGRPC
279        name = self.make_resource_name(self.ALTERNATIVE_BACKEND_SERVICE_NAME)
280        logger.info('Creating %s Alternative Backend Service "%s"',
281                    protocol.name, name)
282        resource = self.compute.create_backend_service_traffic_director(
283            name, health_check=self.health_check, protocol=protocol)
284        self.alternative_backend_service = resource
285        self.alternative_backend_service_protocol = protocol
286
287    def load_alternative_backend_service(self):
288        name = self.make_resource_name(self.ALTERNATIVE_BACKEND_SERVICE_NAME)
289        resource = self.compute.get_backend_service_traffic_director(name)
290        self.alternative_backend_service = resource
291
292    def delete_alternative_backend_service(self, force=False):
293        if force:
294            name = self.make_resource_name(
295                self.ALTERNATIVE_BACKEND_SERVICE_NAME)
296        elif self.alternative_backend_service:
297            name = self.alternative_backend_service.name
298        else:
299            return
300        logger.info('Deleting Alternative Backend Service "%s"', name)
301        self.compute.delete_backend_service(name)
302        self.alternative_backend_service = None
303
304    def alternative_backend_service_add_neg_backends(self, name, zones):
305        logger.info('Waiting for Network Endpoint Groups to load endpoints.')
306        for zone in zones:
307            backend = self.compute.wait_for_network_endpoint_group(name, zone)
308            logger.info('Loaded NEG "%s" in zone %s', backend.name,
309                        backend.zone)
310            self.alternative_backends.add(backend)
311        self.alternative_backend_service_patch_backends()
312
313    def alternative_backend_service_patch_backends(self):
314        logging.info('Adding backends to Backend Service %s: %r',
315                     self.alternative_backend_service.name,
316                     self.alternative_backends)
317        self.compute.backend_service_patch_backends(
318            self.alternative_backend_service, self.alternative_backends)
319
320    def alternative_backend_service_remove_all_backends(self):
321        logging.info('Removing backends from Backend Service %s',
322                     self.alternative_backend_service.name)
323        self.compute.backend_service_remove_all_backends(
324            self.alternative_backend_service)
325
    def wait_for_alternative_backends_healthy_status(self):
        """Delegate the health wait for the alternative backend service.

        Passes the alternative backend service and its tracked backends to
        the Compute API wrapper's wait_for_backends_healthy_status().
        """
        logger.debug(
            "Waiting for Backend Service %s to report all backends healthy %r",
            self.alternative_backend_service, self.alternative_backends)
        self.compute.wait_for_backends_healthy_status(
            self.alternative_backend_service, self.alternative_backends)
332
333    def create_affinity_backend_service(
334            self, protocol: Optional[BackendServiceProtocol] = _BackendGRPC):
335        if protocol is None:
336            protocol = _BackendGRPC
337        name = self.make_resource_name(self.AFFINITY_BACKEND_SERVICE_NAME)
338        logger.info('Creating %s Affinity Backend Service "%s"', protocol.name,
339                    name)
340        resource = self.compute.create_backend_service_traffic_director(
341            name,
342            health_check=self.health_check,
343            protocol=protocol,
344            affinity_header=TEST_AFFINITY_METADATA_KEY)
345        self.affinity_backend_service = resource
346        self.affinity_backend_service_protocol = protocol
347
348    def load_affinity_backend_service(self):
349        name = self.make_resource_name(self.AFFINITY_BACKEND_SERVICE_NAME)
350        resource = self.compute.get_backend_service_traffic_director(name)
351        self.affinity_backend_service = resource
352
353    def delete_affinity_backend_service(self, force=False):
354        if force:
355            name = self.make_resource_name(self.AFFINITY_BACKEND_SERVICE_NAME)
356        elif self.affinity_backend_service:
357            name = self.affinity_backend_service.name
358        else:
359            return
360        logger.info('Deleting Affinity Backend Service "%s"', name)
361        self.compute.delete_backend_service(name)
362        self.affinity_backend_service = None
363
364    def affinity_backend_service_add_neg_backends(self, name, zones):
365        logger.info('Waiting for Network Endpoint Groups to load endpoints.')
366        for zone in zones:
367            backend = self.compute.wait_for_network_endpoint_group(name, zone)
368            logger.info('Loaded NEG "%s" in zone %s', backend.name,
369                        backend.zone)
370            self.affinity_backends.add(backend)
371        self.affinity_backend_service_patch_backends()
372
373    def affinity_backend_service_patch_backends(self):
374        logging.info('Adding backends to Backend Service %s: %r',
375                     self.affinity_backend_service.name, self.affinity_backends)
376        self.compute.backend_service_patch_backends(
377            self.affinity_backend_service, self.affinity_backends)
378
379    def affinity_backend_service_remove_all_backends(self):
380        logging.info('Removing backends from Backend Service %s',
381                     self.affinity_backend_service.name)
382        self.compute.backend_service_remove_all_backends(
383            self.affinity_backend_service)
384
    def wait_for_affinity_backends_healthy_status(self):
        """Delegate the health wait for the affinity backend service.

        Passes the affinity backend service and its tracked backends to the
        Compute API wrapper's wait_for_backends_healthy_status().
        """
        logger.debug(
            "Waiting for Backend Service %s to report all backends healthy %r",
            self.affinity_backend_service, self.affinity_backends)
        self.compute.wait_for_backends_healthy_status(
            self.affinity_backend_service, self.affinity_backends)
391
392    @staticmethod
393    def _generate_url_map_body(
394        name: str,
395        matcher_name: str,
396        src_hosts,
397        dst_default_backend_service: GcpResource,
398        dst_host_rule_match_backend_service: Optional[GcpResource] = None,
399    ) -> Dict[str, Any]:
400        if dst_host_rule_match_backend_service is None:
401            dst_host_rule_match_backend_service = dst_default_backend_service
402        return {
403            'name':
404                name,
405            'defaultService':
406                dst_default_backend_service.url,
407            'hostRules': [{
408                'hosts': src_hosts,
409                'pathMatcher': matcher_name,
410            }],
411            'pathMatchers': [{
412                'name': matcher_name,
413                'defaultService': dst_host_rule_match_backend_service.url,
414            }],
415        }
416
417    def create_url_map(self, src_host: str, src_port: int) -> GcpResource:
418        src_address = f'{src_host}:{src_port}'
419        name = self.make_resource_name(self.URL_MAP_NAME)
420        matcher_name = self.make_resource_name(self.URL_MAP_PATH_MATCHER_NAME)
421        logger.info('Creating URL map "%s": %s -> %s', name, src_address,
422                    self.backend_service.name)
423        resource = self.compute.create_url_map_with_content(
424            self._generate_url_map_body(name, matcher_name, [src_address],
425                                        self.backend_service))
426        self.url_map = resource
427        return resource
428
429    def patch_url_map(self, src_host: str, src_port: int,
430                      backend_service: GcpResource):
431        src_address = f'{src_host}:{src_port}'
432        name = self.make_resource_name(self.URL_MAP_NAME)
433        matcher_name = self.make_resource_name(self.URL_MAP_PATH_MATCHER_NAME)
434        logger.info('Patching URL map "%s": %s -> %s', name, src_address,
435                    backend_service.name)
436        self.compute.patch_url_map(
437            self.url_map,
438            self._generate_url_map_body(name, matcher_name, [src_address],
439                                        backend_service))
440
    def create_url_map_with_content(self, url_map_body: Any) -> GcpResource:
        """Create a URL map from a caller-supplied body and track it.

        Args:
            url_map_body: Body passed unchanged to the Compute API wrapper.

        Returns:
            The created URL map, also stored in ``self.url_map``.
        """
        logger.info('Creating URL map: %s', url_map_body)
        resource = self.compute.create_url_map_with_content(url_map_body)
        self.url_map = resource
        return resource
446
447    def delete_url_map(self, force=False):
448        if force:
449            name = self.make_resource_name(self.URL_MAP_NAME)
450        elif self.url_map:
451            name = self.url_map.name
452        else:
453            return
454        logger.info('Deleting URL Map "%s"', name)
455        self.compute.delete_url_map(name)
456        self.url_map = None
457
458    def create_alternative_url_map(
459            self,
460            src_host: str,
461            src_port: int,
462            backend_service: Optional[GcpResource] = None) -> GcpResource:
463        name = self.make_resource_name(self.ALTERNATIVE_URL_MAP_NAME)
464        src_address = f'{src_host}:{src_port}'
465        matcher_name = self.make_resource_name(self.URL_MAP_PATH_MATCHER_NAME)
466        if backend_service is None:
467            backend_service = self.alternative_backend_service
468        logger.info('Creating alternative URL map "%s": %s -> %s', name,
469                    src_address, backend_service.name)
470        resource = self.compute.create_url_map_with_content(
471            self._generate_url_map_body(name, matcher_name, [src_address],
472                                        backend_service))
473        self.alternative_url_map = resource
474        return resource
475
476    def delete_alternative_url_map(self, force=False):
477        if force:
478            name = self.make_resource_name(self.ALTERNATIVE_URL_MAP_NAME)
479        elif self.alternative_url_map:
480            name = self.alternative_url_map.name
481        else:
482            return
483        logger.info('Deleting alternative URL Map "%s"', name)
484        self.compute.delete_url_map(name)
485        self.url_map = None
486
487    def create_target_proxy(self):
488        name = self.make_resource_name(self.TARGET_PROXY_NAME)
489        if self.backend_service_protocol is BackendServiceProtocol.GRPC:
490            target_proxy_type = 'GRPC'
491            create_proxy_fn = self.compute.create_target_grpc_proxy
492            self.target_proxy_is_http = False
493        elif self.backend_service_protocol is BackendServiceProtocol.HTTP2:
494            target_proxy_type = 'HTTP'
495            create_proxy_fn = self.compute.create_target_http_proxy
496            self.target_proxy_is_http = True
497        else:
498            raise TypeError('Unexpected backend service protocol')
499
500        logger.info('Creating target %s proxy "%s" to URL map %s', name,
501                    target_proxy_type, self.url_map.name)
502        self.target_proxy = create_proxy_fn(name, self.url_map)
503
504    def delete_target_grpc_proxy(self, force=False):
505        if force:
506            name = self.make_resource_name(self.TARGET_PROXY_NAME)
507        elif self.target_proxy:
508            name = self.target_proxy.name
509        else:
510            return
511        logger.info('Deleting Target GRPC proxy "%s"', name)
512        self.compute.delete_target_grpc_proxy(name)
513        self.target_proxy = None
514        self.target_proxy_is_http = False
515
516    def delete_target_http_proxy(self, force=False):
517        if force:
518            name = self.make_resource_name(self.TARGET_PROXY_NAME)
519        elif self.target_proxy and self.target_proxy_is_http:
520            name = self.target_proxy.name
521        else:
522            return
523        logger.info('Deleting HTTP Target proxy "%s"', name)
524        self.compute.delete_target_http_proxy(name)
525        self.target_proxy = None
526        self.target_proxy_is_http = False
527
528    def create_alternative_target_proxy(self):
529        name = self.make_resource_name(self.ALTERNATIVE_TARGET_PROXY_NAME)
530        if self.backend_service_protocol is BackendServiceProtocol.GRPC:
531            logger.info(
532                'Creating alternative target GRPC proxy "%s" to URL map %s',
533                name, self.alternative_url_map.name)
534            self.alternative_target_proxy = self.compute.create_target_grpc_proxy(
535                name, self.alternative_url_map, False)
536        else:
537            raise TypeError('Unexpected backend service protocol')
538
539    def delete_alternative_target_grpc_proxy(self, force=False):
540        if force:
541            name = self.make_resource_name(self.ALTERNATIVE_TARGET_PROXY_NAME)
542        elif self.alternative_target_proxy:
543            name = self.alternative_target_proxy.name
544        else:
545            return
546        logger.info('Deleting alternative Target GRPC proxy "%s"', name)
547        self.compute.delete_target_grpc_proxy(name)
548        self.alternative_target_proxy = None
549
550    def find_unused_forwarding_rule_port(
551            self,
552            *,
553            lo: int = 1024,  # To avoid confusion, skip well-known ports.
554            hi: int = 65535,
555            attempts: int = 25) -> int:
556        for _ in range(attempts):
557            src_port = random.randint(lo, hi)
558            if not self.compute.exists_forwarding_rule(src_port):
559                return src_port
560        # TODO(sergiitk): custom exception
561        raise RuntimeError("Couldn't find unused forwarding rule port")
562
563    def create_forwarding_rule(self, src_port: int):
564        name = self.make_resource_name(self.FORWARDING_RULE_NAME)
565        src_port = int(src_port)
566        logging.info(
567            'Creating forwarding rule "%s" in network "%s": 0.0.0.0:%s -> %s',
568            name, self.network, src_port, self.target_proxy.url)
569        resource = self.compute.create_forwarding_rule(name, src_port,
570                                                       self.target_proxy,
571                                                       self.network_url)
572        self.forwarding_rule = resource
573        return resource
574
575    def delete_forwarding_rule(self, force=False):
576        if force:
577            name = self.make_resource_name(self.FORWARDING_RULE_NAME)
578        elif self.forwarding_rule:
579            name = self.forwarding_rule.name
580        else:
581            return
582        logger.info('Deleting Forwarding rule "%s"', name)
583        self.compute.delete_forwarding_rule(name)
584        self.forwarding_rule = None
585
586    def create_alternative_forwarding_rule(self,
587                                           src_port: int,
588                                           ip_address='0.0.0.0'):
589        name = self.make_resource_name(self.ALTERNATIVE_FORWARDING_RULE_NAME)
590        src_port = int(src_port)
591        logging.info(
592            'Creating alternative forwarding rule "%s" in network "%s": %s:%s -> %s',
593            name, self.network, ip_address, src_port,
594            self.alternative_target_proxy.url)
595        resource = self.compute.create_forwarding_rule(
596            name,
597            src_port,
598            self.alternative_target_proxy,
599            self.network_url,
600            ip_address=ip_address)
601        self.alternative_forwarding_rule = resource
602        return resource
603
604    def delete_alternative_forwarding_rule(self, force=False):
605        if force:
606            name = self.make_resource_name(
607                self.ALTERNATIVE_FORWARDING_RULE_NAME)
608        elif self.alternative_forwarding_rule:
609            name = self.alternative_forwarding_rule.name
610        else:
611            return
612        logger.info('Deleting alternative Forwarding rule "%s"', name)
613        self.compute.delete_forwarding_rule(name)
614        self.alternative_forwarding_rule = None
615
616    def create_firewall_rule(self, allowed_ports: List[str]):
617        name = self.make_resource_name(self.FIREWALL_RULE_NAME)
618        logging.info(
619            'Creating firewall rule "%s" in network "%s" with allowed ports %s',
620            name, self.network, allowed_ports)
621        resource = self.compute.create_firewall_rule(
622            name, self.network_url, xds_flags.FIREWALL_SOURCE_RANGE.value,
623            allowed_ports)
624        self.firewall_rule = resource
625
626    def delete_firewall_rule(self, force=False):
627        """The firewall rule won't be automatically removed."""
628        if force:
629            name = self.make_resource_name(self.FIREWALL_RULE_NAME)
630        elif self.firewall_rule:
631            name = self.firewall_rule.name
632        else:
633            return
634        logger.info('Deleting Firewall Rule "%s"', name)
635        self.compute.delete_firewall_rule(name)
636        self.firewall_rule = None
637
638
class TrafficDirectorAppNetManager(TrafficDirectorManager):
    """TrafficDirectorManager that also manages a Mesh and a GrpcRoute.

    The two extra resources are created and deleted through the Network
    Services v1alpha1 API wrapper held in ``netsvc``.
    """

    GRPC_ROUTE_NAME = "grpc-route"
    MESH_NAME = "mesh"

    netsvc: _NetworkServicesV1Alpha1

    def __init__(self,
                 gcp_api_manager: gcp.api.GcpApiManager,
                 project: str,
                 *,
                 resource_prefix: str,
                 resource_suffix: Optional[str] = None,
                 network: str = 'default',
                 compute_api_version: str = 'v1'):
        """Initialize the base manager and the Network Services wrapper.

        Args:
            gcp_api_manager: Passed to the base class and to the Network
                Services API wrapper.
            project: Passed to the base class and to the Network Services
                API wrapper.
            resource_prefix: Passed to the base class.
            resource_suffix: Passed to the base class.
            network: Passed to the base class.
            compute_api_version: Passed to the base class.
        """
        super().__init__(gcp_api_manager,
                         project,
                         resource_prefix=resource_prefix,
                         resource_suffix=resource_suffix,
                         network=network,
                         compute_api_version=compute_api_version)

        # API
        self.netsvc = _NetworkServicesV1Alpha1(gcp_api_manager, project)

        # Managed resources
        # TODO(gnossen) PTAL at the pylint error
        self.grpc_route: Optional[GrpcRoute] = None
        self.mesh: Optional[Mesh] = None

    def create_mesh(self) -> GcpResource:
        """Create the Mesh with an empty body, then load it.

        Returns:
            The value returned by the create call. The loaded Mesh is stored
            in ``self.mesh``.
        """
        name = self.make_resource_name(self.MESH_NAME)
        logger.info("Creating Mesh %s", name)
        body = {}
        resource = self.netsvc.create_mesh(name, body)
        self.mesh = self.netsvc.get_mesh(name)
        logger.debug("Loaded Mesh: %s", self.mesh)
        return resource

    def delete_mesh(self, force=False):
        """Delete the Mesh and clear ``self.mesh``.

        Args:
            force: When true, delete by the generated name even if no Mesh
                is tracked. Otherwise do nothing unless one is tracked.
        """
        if force:
            name = self.make_resource_name(self.MESH_NAME)
        elif self.mesh:
            name = self.mesh.name
        else:
            return
        logger.info('Deleting Mesh %s', name)
        self.netsvc.delete_mesh(name)
        self.mesh = None

    def create_grpc_route(self, src_host: str, src_port: int) -> GcpResource:
        """Create a GrpcRoute sending ``host:port`` to the backend service.

        The route is attached to the tracked Mesh and has a single rule whose
        only destination is the primary backend service.

        Args:
            src_host: Host part of the route hostname.
            src_port: Port part of the route hostname.

        Returns:
            The value returned by the create call. The loaded GrpcRoute is
            stored in ``self.grpc_route``.
        """
        host = f'{src_host}:{src_port}'
        service_name = self.netsvc.resource_full_name(self.backend_service.name,
                                                      "backendServices")
        body = {
            "meshes": [self.mesh.url],
            # GrpcRoute.hostnames is a repeated field, so send a list rather
            # than a bare string.
            "hostnames": [host],
            "rules": [{
                "action": {
                    "destinations": [{
                        "serviceName": service_name
                    }]
                }
            }],
        }
        # Creation, loading and logging are shared with the raw-body variant.
        return self.create_grpc_route_with_content(body)

    def create_grpc_route_with_content(self, body: Any) -> GcpResource:
        """Create a GrpcRoute from a caller-supplied body, then load it.

        Args:
            body: Body passed unchanged to the Network Services API wrapper.

        Returns:
            The value returned by the create call. The loaded GrpcRoute is
            stored in ``self.grpc_route``.
        """
        name = self.make_resource_name(self.GRPC_ROUTE_NAME)
        logger.info("Creating GrpcRoute %s", name)
        resource = self.netsvc.create_grpc_route(name, body)
        self.grpc_route = self.netsvc.get_grpc_route(name)
        logger.debug("Loaded GrpcRoute: %s", self.grpc_route)
        return resource

    def delete_grpc_route(self, force=False):
        """Delete the GrpcRoute and clear ``self.grpc_route``.

        Args:
            force: When true, delete by the generated name even if no
                GrpcRoute is tracked. Otherwise do nothing unless one is
                tracked.
        """
        if force:
            name = self.make_resource_name(self.GRPC_ROUTE_NAME)
        elif self.grpc_route:
            name = self.grpc_route.name
        else:
            return
        logger.info('Deleting GrpcRoute %s', name)
        self.netsvc.delete_grpc_route(name)
        self.grpc_route = None

    def cleanup(self, *, force=False):
        """Delete the GrpcRoute and the Mesh, then run the base cleanup.

        Args:
            force: Forwarded unchanged to every delete step.
        """
        # The route references both the mesh and the backend service, so it
        # is deleted first.
        self.delete_grpc_route(force=force)
        self.delete_mesh(force=force)
        super().cleanup(force=force)
735
736
class TrafficDirectorSecureManager(TrafficDirectorManager):
    """Traffic Director manager that also manages security resources.

    In addition to what TrafficDirectorManager manages, this class creates
    and deletes Server TLS, Client TLS and Authorization policies through
    the Network Security API, and an Endpoint Policy through the Network
    Services API.
    """
    # Base names passed to make_resource_name() for each managed resource.
    SERVER_TLS_POLICY_NAME = "server-tls-policy"
    CLIENT_TLS_POLICY_NAME = "client-tls-policy"
    AUTHZ_POLICY_NAME = "authz-policy"
    ENDPOINT_POLICY = "endpoint-policy"
    # Plugin instance name placed into every certificate provider config.
    CERTIFICATE_PROVIDER_INSTANCE = "google_cloud_private_spiffe"

    netsec: _NetworkSecurityV1Beta1
    netsvc: _NetworkServicesV1Beta1

    def __init__(
        self,
        gcp_api_manager: gcp.api.GcpApiManager,
        project: str,
        *,
        resource_prefix: str,
        resource_suffix: Optional[str] = None,
        network: str = 'default',
        compute_api_version: str = 'v1',
    ):
        """Initialize the API clients; no security resources exist yet.

        All arguments are forwarded unchanged to the parent constructor;
        gcp_api_manager and project are also used to build the Network
        Security and Network Services API clients.
        """
        super().__init__(gcp_api_manager,
                         project,
                         resource_prefix=resource_prefix,
                         resource_suffix=resource_suffix,
                         network=network,
                         compute_api_version=compute_api_version)

        # API
        self.netsec = _NetworkSecurityV1Beta1(gcp_api_manager, project)
        self.netsvc = _NetworkServicesV1Beta1(gcp_api_manager, project)

        # Managed resources
        self.server_tls_policy: Optional[ServerTlsPolicy] = None
        self.client_tls_policy: Optional[ClientTlsPolicy] = None
        self.authz_policy: Optional[AuthorizationPolicy] = None
        self.endpoint_policy: Optional[EndpointPolicy] = None

    def setup_server_security(self,
                              *,
                              server_namespace,
                              server_name,
                              server_port,
                              tls=True,
                              mtls=True):
        """Create the Server TLS Policy, then the Endpoint Policy.

        The Server TLS Policy is created first so that the Endpoint Policy
        can reference it.

        Args:
            server_namespace: Forwarded to create_endpoint_policy.
            server_name: Forwarded to create_endpoint_policy.
            server_port: Forwarded to create_endpoint_policy.
            tls: Forwarded to create_server_tls_policy.
            mtls: Forwarded to create_server_tls_policy.
        """
        self.create_server_tls_policy(tls=tls, mtls=mtls)
        self.create_endpoint_policy(server_namespace=server_namespace,
                                    server_name=server_name,
                                    server_port=server_port)

    def setup_client_security(self,
                              *,
                              server_namespace,
                              server_name,
                              tls=True,
                              mtls=True):
        """Create the Client TLS Policy and attach it to the backend service.

        Args:
            server_namespace: Forwarded to
                backend_service_apply_client_mtls_policy.
            server_name: Forwarded to
                backend_service_apply_client_mtls_policy.
            tls: Forwarded to create_client_tls_policy.
            mtls: Forwarded to create_client_tls_policy.
        """
        self.create_client_tls_policy(tls=tls, mtls=mtls)
        self.backend_service_apply_client_mtls_policy(server_namespace,
                                                      server_name)

    def cleanup(self, *, force=False):
        """Run the parent cleanup, then delete the security resources.

        Args:
            force: Forwarded as-is to the parent cleanup and to every
                delete call.
        """
        # Cleanup in the reverse order of creation
        super().cleanup(force=force)
        self.delete_endpoint_policy(force=force)
        self.delete_server_tls_policy(force=force)
        self.delete_client_tls_policy(force=force)
        self.delete_authz_policy(force=force)

    def create_server_tls_policy(self, *, tls: bool, mtls: bool) -> None:
        """Create the Server TLS Policy and store it on self.server_tls_policy.

        Nothing is created when both tls and mtls are false.

        Args:
            tls: When true, the policy gets a "serverCertificate" entry.
            mtls: When true, the policy gets an "mtlsPolicy" entry with a
                client validation CA.
        """
        name = self.make_resource_name(self.SERVER_TLS_POLICY_NAME)
        logger.info('Creating Server TLS Policy %s', name)
        if not tls and not mtls:
            logger.warning(
                'Server TLS Policy %s neither TLS, nor mTLS '
                'policy. Skipping creation', name)
            return

        certificate_provider = self._get_certificate_provider()
        policy = {}
        if tls:
            policy["serverCertificate"] = certificate_provider
        if mtls:
            policy["mtlsPolicy"] = {
                "clientValidationCa": [certificate_provider],
            }

        self.netsec.create_server_tls_policy(name, policy)
        self.server_tls_policy = self.netsec.get_server_tls_policy(name)
        logger.debug('Server TLS Policy loaded: %r', self.server_tls_policy)

    def delete_server_tls_policy(self, force: bool = False) -> None:
        """Delete the Server TLS Policy, if there is one to delete.

        Args:
            force: When true, the name is computed with make_resource_name
                and deletion is attempted even if no policy is tracked.
                Otherwise the call is a no-op when no policy is tracked.
        """
        if force:
            name = self.make_resource_name(self.SERVER_TLS_POLICY_NAME)
        elif self.server_tls_policy:
            name = self.server_tls_policy.name
        else:
            return
        logger.info('Deleting Server TLS Policy %s', name)
        self.netsec.delete_server_tls_policy(name)
        self.server_tls_policy = None

    def create_authz_policy(self, *, action: str, rules: list) -> None:
        """Create the Authorization Policy and store it on self.authz_policy.

        Args:
            action: Sent as the policy's "action" field.
            rules: Sent as the policy's "rules" field.
        """
        name = self.make_resource_name(self.AUTHZ_POLICY_NAME)
        logger.info('Creating Authz Policy %s', name)
        policy = {
            "action": action,
            "rules": rules,
        }

        self.netsec.create_authz_policy(name, policy)
        self.authz_policy = self.netsec.get_authz_policy(name)
        logger.debug('Authz Policy loaded: %r', self.authz_policy)

    def delete_authz_policy(self, force: bool = False) -> None:
        """Delete the Authorization Policy, if there is one to delete.

        Args:
            force: When true, the name is computed with make_resource_name
                and deletion is attempted even if no policy is tracked.
                Otherwise the call is a no-op when no policy is tracked.
        """
        if force:
            name = self.make_resource_name(self.AUTHZ_POLICY_NAME)
        elif self.authz_policy:
            name = self.authz_policy.name
        else:
            return
        logger.info('Deleting Authz Policy %s', name)
        self.netsec.delete_authz_policy(name)
        self.authz_policy = None

    def create_endpoint_policy(self, *, server_namespace: str, server_name: str,
                               server_port: int) -> None:
        """Create the Endpoint Policy and store it on self.endpoint_policy.

        The policy has type GRPC_SERVER, selects the given port, and matches
        endpoints whose "app" metadata label equals
        "{server_namespace}-{server_name}". The Server TLS Policy and the
        Authorization Policy are attached only when they are already set on
        this instance.

        Args:
            server_namespace: First part of the "app" label value to match.
            server_name: Second part of the "app" label value to match.
            server_port: Port placed (as a string) into the port selector.
        """
        name = self.make_resource_name(self.ENDPOINT_POLICY)
        logger.info('Creating Endpoint Policy %s', name)
        endpoint_matcher_labels = [{
            "labelName": "app",
            "labelValue": f"{server_namespace}-{server_name}"
        }]
        port_selector = {"ports": [str(server_port)]}
        label_matcher_all = {
            "metadataLabelMatchCriteria": "MATCH_ALL",
            "metadataLabels": endpoint_matcher_labels,
        }
        config = {
            "type": "GRPC_SERVER",
            "trafficPortSelector": port_selector,
            "endpointMatcher": {
                "metadataLabelMatcher": label_matcher_all,
            },
        }
        if self.server_tls_policy:
            config["serverTlsPolicy"] = self.server_tls_policy.name
        else:
            logger.warning(
                'Creating Endpoint Policy %s with '
                'no Server TLS policy attached', name)
        if self.authz_policy:
            config["authorizationPolicy"] = self.authz_policy.name

        self.netsvc.create_endpoint_policy(name, config)
        self.endpoint_policy = self.netsvc.get_endpoint_policy(name)
        logger.debug('Loaded Endpoint Policy: %r', self.endpoint_policy)

    def delete_endpoint_policy(self, force: bool = False) -> None:
        """Delete the Endpoint Policy, if there is one to delete.

        Args:
            force: When true, the name is computed with make_resource_name
                and deletion is attempted even if no policy is tracked.
                Otherwise the call is a no-op when no policy is tracked.
        """
        if force:
            name = self.make_resource_name(self.ENDPOINT_POLICY)
        elif self.endpoint_policy:
            name = self.endpoint_policy.name
        else:
            return
        logger.info('Deleting Endpoint Policy %s', name)
        self.netsvc.delete_endpoint_policy(name)
        self.endpoint_policy = None

    def create_client_tls_policy(self, *, tls: bool, mtls: bool) -> None:
        """Create the Client TLS Policy and store it on self.client_tls_policy.

        Nothing is created when both tls and mtls are false.

        Args:
            tls: When true, the policy gets a "serverValidationCa" entry.
            mtls: When true, the policy gets a "clientCertificate" entry.
        """
        name = self.make_resource_name(self.CLIENT_TLS_POLICY_NAME)
        logger.info('Creating Client TLS Policy %s', name)
        if not tls and not mtls:
            logger.warning(
                'Client TLS Policy %s neither TLS, nor mTLS '
                'policy. Skipping creation', name)
            return

        certificate_provider = self._get_certificate_provider()
        policy = {}
        if tls:
            policy["serverValidationCa"] = [certificate_provider]
        if mtls:
            policy["clientCertificate"] = certificate_provider

        self.netsec.create_client_tls_policy(name, policy)
        self.client_tls_policy = self.netsec.get_client_tls_policy(name)
        logger.debug('Client TLS Policy loaded: %r', self.client_tls_policy)

    def delete_client_tls_policy(self, force: bool = False) -> None:
        """Delete the Client TLS Policy, if there is one to delete.

        Args:
            force: When true, the name is computed with make_resource_name
                and deletion is attempted even if no policy is tracked.
                Otherwise the call is a no-op when no policy is tracked.
        """
        if force:
            name = self.make_resource_name(self.CLIENT_TLS_POLICY_NAME)
        elif self.client_tls_policy:
            name = self.client_tls_policy.name
        else:
            return
        logger.info('Deleting Client TLS Policy %s', name)
        self.netsec.delete_client_tls_policy(name)
        self.client_tls_policy = None

    def backend_service_apply_client_mtls_policy(
        self,
        server_namespace,
        server_name,
    ):
        """Attach the Client TLS Policy to the backend service.

        Patches the backend service's "securitySettings" with the URL of
        the Client TLS Policy and a single subject alternative name: the
        server SPIFFE ID built from the project, namespace and server name.
        Does nothing (besides a warning) when no Client TLS Policy is set.

        Args:
            server_namespace: Used as the "ns/" component of the SPIFFE ID.
            server_name: Used as the "sa/" component of the SPIFFE ID.
        """
        if not self.client_tls_policy:
            logger.warning(
                'Client TLS policy not created, '
                'skipping attaching to Backend Service %s',
                self.backend_service.name)
            return

        server_spiffe = (f'spiffe://{self.project}.svc.id.goog/'
                         f'ns/{server_namespace}/sa/{server_name}')
        # Log through the module-level logger, like every other message in
        # this module, instead of the root logger.
        logger.info(
            'Adding Client TLS Policy to Backend Service %s: %s, '
            'server %s', self.backend_service.name, self.client_tls_policy.url,
            server_spiffe)

        self.compute.patch_backend_service(
            self.backend_service, {
                'securitySettings': {
                    'clientTlsPolicy': self.client_tls_policy.url,
                    'subjectAltNames': [server_spiffe]
                }
            })

    @classmethod
    def _get_certificate_provider(cls) -> Dict[str, Any]:
        """Return a certificate provider config for the class plugin instance.

        A new dict is built on every call, with the plugin instance name
        taken from CERTIFICATE_PROVIDER_INSTANCE.
        """
        return {
            "certificateProviderInstance": {
                "pluginInstance": cls.CERTIFICATE_PROVIDER_INSTANCE,
            },
        }
969