1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import abc
16import dataclasses
17import logging
18from typing import Any, Dict
19
20from google.rpc import code_pb2
21import tenacity
22
23from framework.infrastructure import gcp
24
# Module-level logger; also handed to tenacity's before_sleep_log in
# _NetworkSecurityBase._execute so that retries are reported at DEBUG level.
logger = logging.getLogger(__name__)

# Type aliases
# Shorthand used as the return annotation of the create_* methods below.
GcpResource = gcp.compute.ComputeV1.GcpResource
29
30
@dataclasses.dataclass(frozen=True)
class ServerTlsPolicy:
    """Immutable view of a serverTlsPolicies API response."""
    # Value of the response's 'name' key.
    url: str
    # Short name supplied by the caller (not read from the response).
    name: str
    # Value of 'serverCertificate'; empty dict when the key is absent.
    server_certificate: dict
    # Value of 'mtlsPolicy'; empty dict when the key is absent.
    mtls_policy: dict
    # Value of 'updateTime'.
    update_time: str
    # Value of 'createTime'.
    create_time: str

    @classmethod
    def from_response(cls, name: str,
                      response: Dict[str, Any]) -> 'ServerTlsPolicy':
        """Build a ServerTlsPolicy from a decoded API response.

        Args:
            name: Short policy name to store on the instance.
            response: Response mapping. Must contain 'name', 'createTime' and
                'updateTime'; 'serverCertificate' and 'mtlsPolicy' are
                optional and default to an empty dict.

        Returns:
            A new frozen ServerTlsPolicy.

        Raises:
            KeyError: If a required key is missing from the response.
        """
        # Dict literals evaluate in order, so a missing required key is
        # reported in the same order as a direct constructor call would.
        init_kwargs = {
            'name': name,
            'url': response['name'],
            'server_certificate': response.get('serverCertificate', {}),
            'mtls_policy': response.get('mtlsPolicy', {}),
            'create_time': response['createTime'],
            'update_time': response['updateTime'],
        }
        return cls(**init_kwargs)
49
50
@dataclasses.dataclass(frozen=True)
class ClientTlsPolicy:
    """Immutable view of a clientTlsPolicies API response."""
    # Value of the response's 'name' key.
    url: str
    # Short name supplied by the caller (not read from the response).
    name: str
    # Value of 'clientCertificate'; empty dict when the key is absent.
    client_certificate: dict
    # Value of 'serverValidationCa'; empty list when the key is absent.
    server_validation_ca: list
    # Value of 'updateTime'.
    update_time: str
    # Value of 'createTime'.
    create_time: str

    @classmethod
    def from_response(cls, name: str,
                      response: Dict[str, Any]) -> 'ClientTlsPolicy':
        """Build a ClientTlsPolicy from a decoded API response.

        Args:
            name: Short policy name to store on the instance.
            response: Response mapping. Must contain 'name', 'createTime' and
                'updateTime'; 'clientCertificate' defaults to an empty dict
                and 'serverValidationCa' to an empty list when absent.

        Returns:
            A new frozen ClientTlsPolicy.

        Raises:
            KeyError: If a required key is missing from the response.
        """
        # Lookups are kept in the original order so that the first missing
        # required key raises the same KeyError.
        resource_url = response['name']
        certificate = response.get('clientCertificate', {})
        validation_ca = response.get('serverValidationCa', [])
        created = response['createTime']
        updated = response['updateTime']
        return cls(url=resource_url,
                   name=name,
                   client_certificate=certificate,
                   server_validation_ca=validation_ca,
                   update_time=updated,
                   create_time=created)
69
70
@dataclasses.dataclass(frozen=True)
class AuthorizationPolicy:
    """Immutable view of an authorizationPolicies API response."""
    # Value of the response's 'name' key.
    url: str
    # Short name supplied by the caller (not read from the response).
    name: str
    # Value of 'updateTime'.
    update_time: str
    # Value of 'createTime'.
    create_time: str
    # Value of 'action' (required in the response).
    action: str
    # Value of 'rules'; empty list when the key is absent.
    rules: list

    @classmethod
    def from_response(cls, name: str,
                      response: Dict[str, Any]) -> 'AuthorizationPolicy':
        """Build an AuthorizationPolicy from a decoded API response.

        Args:
            name: Short policy name to store on the instance.
            response: Response mapping. Must contain 'name', 'createTime',
                'updateTime' and 'action'; 'rules' defaults to an empty list
                when absent.

        Returns:
            A new frozen AuthorizationPolicy.

        Raises:
            KeyError: If a required key is missing from the response.
        """
        # Dict literals evaluate in order, so a missing required key is
        # reported in the same order as a direct constructor call would.
        init_kwargs = {
            'name': name,
            'url': response['name'],
            'create_time': response['createTime'],
            'update_time': response['updateTime'],
            'action': response['action'],
            'rules': response.get('rules', []),
        }
        return cls(**init_kwargs)
89
90
class _NetworkSecurityBase(gcp.api.GcpStandardCloudApiResource,
                           metaclass=abc.ABCMeta):
    """Common plumbing shared by the versioned NetworkSecurity API classes."""

    # TODO(https://github.com/grpc/grpc/issues/29532) remove pylint disable
    # pylint: disable=abstract-method

    def __init__(self, api_manager: gcp.api.GcpApiManager, project: str):
        """Bind to the networksecurity API at the subclass's api_version.

        Args:
            api_manager: Manager used to obtain the networksecurity API.
            project: Project passed through to the parent constructor.
        """
        networksecurity_api = api_manager.networksecurity(self.api_version)
        super().__init__(networksecurity_api, project)
        # Shortcut to projects/*/locations/ endpoints
        self._api_locations = self.api.projects().locations()

    @property
    def api_name(self) -> str:
        """Name of the wrapped API: 'networksecurity'."""
        return 'networksecurity'

    def _execute(self, *args, **kwargs):  # pylint: disable=signature-differs,arguments-differ
        """Call the parent's _execute, retrying on INTERNAL operation errors.

        Retries every 10 seconds for up to 5 minutes; once retries are
        exhausted the last exception is re-raised. All arguments are
        forwarded unchanged. Nothing is returned.
        """
        # Workaround TD bug: throttled operations are reported as internal.
        # Ref b/175345578
        retry_policy = tenacity.retry_if_exception(
            self._operation_internal_error)
        wait_policy = tenacity.wait_fixed(10)
        stop_policy = tenacity.stop_after_delay(5 * 60)
        log_before_sleep = tenacity.before_sleep_log(logger, logging.DEBUG)
        call_with_retries = tenacity.Retrying(retry=retry_policy,
                                              wait=wait_policy,
                                              stop=stop_policy,
                                              before_sleep=log_before_sleep,
                                              reraise=True)
        call_with_retries(super()._execute, *args, **kwargs)

    @staticmethod
    def _operation_internal_error(exception):
        """Tell whether the exception is an OperationError with code INTERNAL."""
        if not isinstance(exception, gcp.api.OperationError):
            return False
        return exception.error.code == code_pb2.INTERNAL
122
123
class NetworkSecurityV1Beta1(_NetworkSecurityBase):
    """NetworkSecurity API v1beta1.

    Provides create/get/delete helpers for server TLS policies, client TLS
    policies and authorization policies. The actual requests are issued by
    the inherited _create_resource, _get_resource and _delete_resource
    helpers, which are defined outside this file.
    """

    # Collection names passed to resource_full_name() when building the full
    # resource name for get/delete requests.
    SERVER_TLS_POLICIES = 'serverTlsPolicies'
    CLIENT_TLS_POLICIES = 'clientTlsPolicies'
    AUTHZ_POLICIES = 'authorizationPolicies'

    @property
    def api_version(self) -> str:
        """API version string used to select the API: 'v1beta1'."""
        return 'v1beta1'

    def create_server_tls_policy(self, name: str, body: dict) -> GcpResource:
        """Create a server TLS policy.

        Args:
            name: Policy id, sent as the serverTlsPolicyId parameter.
            body: Request body describing the policy.

        Returns:
            Whatever the inherited _create_resource helper returns.
        """
        return self._create_resource(
            collection=self._api_locations.serverTlsPolicies(),
            body=body,
            serverTlsPolicyId=name)

    def get_server_tls_policy(self, name: str) -> ServerTlsPolicy:
        """Fetch a server TLS policy by short name.

        Args:
            name: Short policy name; expanded via resource_full_name().

        Returns:
            A ServerTlsPolicy built from the API response.
        """
        response = self._get_resource(
            collection=self._api_locations.serverTlsPolicies(),
            full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES))
        return ServerTlsPolicy.from_response(name, response)

    def delete_server_tls_policy(self, name: str) -> bool:
        """Delete a server TLS policy by short name.

        Args:
            name: Short policy name; expanded via resource_full_name().

        Returns:
            Result of the inherited _delete_resource helper (presumably
            whether the deletion succeeded; confirm against the base class).
        """
        return self._delete_resource(
            collection=self._api_locations.serverTlsPolicies(),
            full_name=self.resource_full_name(name, self.SERVER_TLS_POLICIES))

    def create_client_tls_policy(self, name: str, body: dict) -> GcpResource:
        """Create a client TLS policy.

        Args:
            name: Policy id, sent as the clientTlsPolicyId parameter.
            body: Request body describing the policy.

        Returns:
            Whatever the inherited _create_resource helper returns.
        """
        return self._create_resource(
            collection=self._api_locations.clientTlsPolicies(),
            body=body,
            clientTlsPolicyId=name)

    def get_client_tls_policy(self, name: str) -> ClientTlsPolicy:
        """Fetch a client TLS policy by short name.

        Args:
            name: Short policy name; expanded via resource_full_name().

        Returns:
            A ClientTlsPolicy built from the API response.
        """
        response = self._get_resource(
            collection=self._api_locations.clientTlsPolicies(),
            full_name=self.resource_full_name(name, self.CLIENT_TLS_POLICIES))
        return ClientTlsPolicy.from_response(name, response)

    def delete_client_tls_policy(self, name: str) -> bool:
        """Delete a client TLS policy by short name.

        Args:
            name: Short policy name; expanded via resource_full_name().

        Returns:
            Result of the inherited _delete_resource helper (presumably
            whether the deletion succeeded; confirm against the base class).
        """
        return self._delete_resource(
            collection=self._api_locations.clientTlsPolicies(),
            full_name=self.resource_full_name(name, self.CLIENT_TLS_POLICIES))

    def create_authz_policy(self, name: str, body: dict) -> GcpResource:
        """Create an authorization policy.

        Args:
            name: Policy id, sent as the authorizationPolicyId parameter.
            body: Request body describing the policy.

        Returns:
            Whatever the inherited _create_resource helper returns.
        """
        return self._create_resource(
            collection=self._api_locations.authorizationPolicies(),
            body=body,
            authorizationPolicyId=name)

    def get_authz_policy(self, name: str) -> AuthorizationPolicy:
        """Fetch an authorization policy by short name.

        Args:
            name: Short policy name; expanded via resource_full_name().

        Returns:
            An AuthorizationPolicy built from the API response, including
            its 'action' and 'rules'.
        """
        response = self._get_resource(
            collection=self._api_locations.authorizationPolicies(),
            full_name=self.resource_full_name(name, self.AUTHZ_POLICIES))
        # The response comes from the authorizationPolicies collection, so it
        # must be parsed as an AuthorizationPolicy; parsing it as a
        # ClientTlsPolicy would drop the 'action' and 'rules' fields.
        return AuthorizationPolicy.from_response(name, response)

    def delete_authz_policy(self, name: str) -> bool:
        """Delete an authorization policy by short name.

        Args:
            name: Short policy name; expanded via resource_full_name().

        Returns:
            Result of the inherited _delete_resource helper (presumably
            whether the deletion succeeded; confirm against the base class).
        """
        return self._delete_resource(
            collection=self._api_locations.authorizationPolicies(),
            full_name=self.resource_full_name(name, self.AUTHZ_POLICIES))
185
186
class NetworkSecurityV1Alpha1(NetworkSecurityV1Beta1):
    """NetworkSecurity API v1alpha1.

    Note: this class reuses the v1beta1 implementation on the presumption
    that v1beta1 is simply the v1alpha1 API promoted to a more stable
    version, which holds in most cases. Where the two versions are
    incompatible, the affected methods can be overridden and reimplemented
    here.
    """

    @property
    def api_version(self) -> str:
        """API version string used to select the API: 'v1alpha1'."""
        return 'v1alpha1'
198