1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
import abc
import contextlib
import functools
import json
import logging
from typing import Any, Dict, List, Optional
import urllib.parse

from absl import flags
from google.cloud import secretmanager_v1
from google.longrunning import operations_pb2
from google.protobuf import json_format
from google.rpc import code_pb2
from google.rpc import error_details_pb2
from google.rpc import status_pb2
from googleapiclient import discovery
import googleapiclient.errors
import googleapiclient.http
import tenacity
import yaml

import framework.helpers.highlighter
35
logger = logging.getLogger(__name__)
# Command-line flags. Each flag supplies the value GcpApiManager uses for the
# constructor argument of the same name when that argument is not passed
# (or is falsy).
PRIVATE_API_KEY_SECRET_NAME = flags.DEFINE_string(
    "private_api_key_secret_name",
    default=None,
    help="Load Private API access key from the latest version of the secret "
    "with the given name, in the format projects/*/secrets/*")
V1_DISCOVERY_URI = flags.DEFINE_string("v1_discovery_uri",
                                       default=discovery.V1_DISCOVERY_URI,
                                       help="Override v1 Discovery URI")
V2_DISCOVERY_URI = flags.DEFINE_string("v2_discovery_uri",
                                       default=discovery.V2_DISCOVERY_URI,
                                       help="Override v2 Discovery URI")
# When set, compute v1 is built from this local file instead of discovery.
COMPUTE_V1_DISCOVERY_FILE = flags.DEFINE_string(
    "compute_v1_discovery_file",
    default=None,
    help="Load compute v1 from discovery file")
GCP_UI_URL = flags.DEFINE_string("gcp_ui_url",
                                 default="console.cloud.google.com",
                                 help="Override GCP UI URL.")

# Type aliases
# googleapiclient error for non-2xx responses; wrapped by ResponseError.
_HttpError = googleapiclient.errors.HttpError
# httplib2's base error (reached via the googleapiclient.http module);
# wrapped by TransportError.
_HttpLib2Error = googleapiclient.http.httplib2.HttpLib2Error
_HighlighterYaml = framework.helpers.highlighter.HighlighterYaml
# google.longrunning Operation proto, used to parse failed operation payloads.
Operation = operations_pb2.Operation
HttpRequest = googleapiclient.http.HttpRequest
62
63
class GcpApiManager:
    """Builds and memoizes clients for the GCP APIs used by the framework.

    Every constructor argument falls back to the command-line flag of the
    same name when it is omitted (or falsy). Discovery-based clients built by
    this manager are entered into an ExitStack and exited by close().

    NOTE: the API builders are memoized with functools.lru_cache applied to
    methods, so the cache keys on `self` and keeps every manager instance
    alive for the lifetime of the process.
    """

    def __init__(self,
                 *,
                 v1_discovery_uri=None,
                 v2_discovery_uri=None,
                 compute_v1_discovery_file=None,
                 private_api_key_secret_name=None,
                 gcp_ui_url=None):
        self.v1_discovery_uri = v1_discovery_uri or V1_DISCOVERY_URI.value
        self.v2_discovery_uri = v2_discovery_uri or V2_DISCOVERY_URI.value
        self.compute_v1_discovery_file = (compute_v1_discovery_file or
                                          COMPUTE_V1_DISCOVERY_FILE.value)
        self.private_api_key_secret_name = (private_api_key_secret_name or
                                            PRIVATE_API_KEY_SECRET_NAME.value)
        self.gcp_ui_url = gcp_ui_url or GCP_UI_URL.value
        # TODO(sergiitk): add options to pass google Credentials
        self._exit_stack = contextlib.ExitStack()

    def close(self):
        """Exit the context of every client built by this manager so far.

        NOTE: the lru_cache'd builders are not evicted here, so calling them
        again after close() returns the same, already-exited client objects.
        """
        self._exit_stack.close()

    @property
    @functools.lru_cache(None)
    def private_api_key(self):
        """
        Private API key.

        Return API key credential that identifies a GCP project allow-listed for
        accessing private API discovery documents.
        https://console.cloud.google.com/apis/credentials

        This method lazy-loads the content of the key from the Secret Manager.
        https://console.cloud.google.com/security/secret-manager

        Raises:
          ValueError: private_api_key_secret_name is not set.
        """
        if not self.private_api_key_secret_name:
            raise ValueError('private_api_key_secret_name must be set to '
                             'access private_api_key.')

        secrets_api = self.secrets('v1')
        # Always read the 'latest' version of the configured secret.
        version_resource_path = secrets_api.secret_version_path(
            **secrets_api.parse_secret_path(self.private_api_key_secret_name),
            secret_version='latest')
        secret: secretmanager_v1.AccessSecretVersionResponse
        secret = secrets_api.access_secret_version(name=version_resource_path)
        return secret.payload.data.decode()

    @functools.lru_cache(None)
    def compute(self, version):
        """Compute Engine API client.

        'v1' is built from compute_v1_discovery_file when set, otherwise from
        the v1 discovery service; 'v1alpha' maps to discovery version 'alpha'.

        Raises:
          NotImplementedError: unsupported version.
        """
        api_name = 'compute'
        if version == 'v1':
            if self.compute_v1_discovery_file:
                return self._build_from_file(self.compute_v1_discovery_file)
            else:
                return self._build_from_discovery_v1(api_name, version)
        elif version == 'v1alpha':
            return self._build_from_discovery_v1(api_name, 'alpha')

        raise NotImplementedError(f'Compute {version} not supported')

    @functools.lru_cache(None)
    def networksecurity(self, version):
        """Network Security API client.

        'v1alpha1' is a private API: its discovery document is requested with
        the private API key and the NETWORKSECURITY_ALPHA visibility label.

        Raises:
          NotImplementedError: unsupported version.
        """
        api_name = 'networksecurity'
        if version == 'v1alpha1':
            return self._build_from_discovery_v2(
                api_name,
                version,
                api_key=self.private_api_key,
                visibility_labels=['NETWORKSECURITY_ALPHA'])
        elif version == 'v1beta1':
            return self._build_from_discovery_v2(api_name, version)

        raise NotImplementedError(f'Network Security {version} not supported')

    @functools.lru_cache(None)
    def networkservices(self, version):
        """Network Services API client.

        'v1alpha1' is a private API: its discovery document is requested with
        the private API key and the NETWORKSERVICES_ALPHA visibility label.

        Raises:
          NotImplementedError: unsupported version.
        """
        api_name = 'networkservices'
        if version == 'v1alpha1':
            return self._build_from_discovery_v2(
                api_name,
                version,
                api_key=self.private_api_key,
                visibility_labels=['NETWORKSERVICES_ALPHA'])
        elif version == 'v1beta1':
            return self._build_from_discovery_v2(api_name, version)

        raise NotImplementedError(f'Network Services {version} not supported')

    @staticmethod
    @functools.lru_cache(None)
    def secrets(version: str):
        """Secret Manager client (shared by all managers, memoized per version).

        Raises:
          NotImplementedError: unsupported version.
        """
        if version == 'v1':
            return secretmanager_v1.SecretManagerServiceClient()

        raise NotImplementedError(f'Secret Manager {version} not supported')

    @functools.lru_cache(None)
    def iam(self, version: str) -> discovery.Resource:
        """Identity and Access Management (IAM) API.

        https://cloud.google.com/iam/docs/reference/rest
        https://googleapis.github.io/google-api-python-client/docs/dyn/iam_v1.html
        """
        api_name = 'iam'
        if version == 'v1':
            return self._build_from_discovery_v1(api_name, version)

        raise NotImplementedError(
            f'Identity and Access Management (IAM) {version} not supported')

    def _build_from_discovery_v1(self, api_name, version):
        """Build a client from the v1 discovery service and track it for close()."""
        api = discovery.build(api_name,
                              version,
                              cache_discovery=False,
                              discoveryServiceUrl=self.v1_discovery_uri)
        self._exit_stack.enter_context(api)
        return api

    def _build_from_discovery_v2(self,
                                 api_name,
                                 version,
                                 *,
                                 api_key: Optional[str] = None,
                                 visibility_labels: Optional[List] = None):
        """Build a client from the v2 discovery service and track it for close().

        Args:
          api_name: API name, e.g. 'networkservices'.
          version: API version, e.g. 'v1alpha1'.
          api_key: when set, sent as the `key` query parameter of the
            discovery request.
          visibility_labels: when set, joined with '_' and sent as the
            `labels` query parameter of the discovery request.
        """
        params = {}
        if api_key:
            params['key'] = api_key
        if visibility_labels:
            # Underscore-separated list of labels.
            params['labels'] = '_'.join(visibility_labels)

        params_str = ''
        if params:
            # Appended with '&', which assumes v2_discovery_uri already
            # contains a query string. urlencode() percent-escapes the values
            # so reserved characters cannot corrupt the discovery URL.
            params_str = '&' + urllib.parse.urlencode(params)

        api = discovery.build(
            api_name,
            version,
            cache_discovery=False,
            discoveryServiceUrl=f'{self.v2_discovery_uri}{params_str}')
        self._exit_stack.enter_context(api)
        return api

    def _build_from_file(self, discovery_file):
        """Build a client from a local discovery document and track it for close()."""
        # Discovery documents are JSON, whose interchange encoding is UTF-8;
        # don't depend on the platform's default locale encoding.
        with open(discovery_file, 'r', encoding='utf-8') as f:
            api = discovery.build_from_document(f.read())
        self._exit_stack.enter_context(api)
        return api
212
213
class Error(Exception):
    """Common ancestor of every exception raised for GCP API failures."""
    pass
216
217
class ResponseError(Error):
    """The response was not a 2xx.

    Wraps a googleapiclient HttpError and exposes its reason, URI, error
    details and HTTP status as attributes.
    """
    reason: str
    uri: str
    error_details: Optional[str]
    status: Optional[int]
    cause: _HttpError

    def __init__(self, cause: _HttpError):
        # TODO(sergiitk): cleanup when we upgrade googleapiclient:
        #  - remove _get_reason()
        #  - remove error_details note
        #  - use status_code()
        self.reason = cause._get_reason().strip()  # noqa
        self.uri = cause.uri
        self.error_details = cause.error_details  # NOTE: Must after _get_reason
        # Stays None when the underlying response carries no status.
        self.status = None
        if cause.resp and cause.resp.status:
            self.status = cause.resp.status
        self.cause = cause
        super().__init__()

    def __repr__(self):
        return (f'<ResponseError {self.status} when requesting {self.uri} '
                f'returned "{self.reason}". Details: "{self.error_details}">')

    def __str__(self):
        # super().__init__() gets no args, so the inherited Exception.__str__
        # would return '' and hide every detail in tracebacks and in
        # "%s"-formatted log messages.
        return self.__repr__()
243
244
class TransportError(Error):
    """A transport error has occurred.

    Wraps the httplib2 error that prevented a response from being received.
    """
    cause: _HttpLib2Error

    def __init__(self, cause: _HttpLib2Error):
        self.cause = cause
        super().__init__()

    def __repr__(self):
        return f'<TransportError cause: {self.cause!r}>'

    def __str__(self):
        # super().__init__() gets no args, so the inherited Exception.__str__
        # would return '' and hide the cause in tracebacks and logs.
        return self.__repr__()
255
256
class OperationError(Error):
    """
    Operation was not successful.

    Assuming Operation based on Google API Style Guide:
    https://cloud.google.com/apis/design/design_patterns#long_running_operations
    https://github.com/googleapis/googleapis/blob/master/google/longrunning/operations.proto
    """
    api_name: str
    name: str
    metadata: Any
    # Name of the google.rpc.Code value, e.g. 'NOT_FOUND'.
    code_name: str
    error: status_pb2.Status

    def __init__(self, api_name: str, response: dict):
        """Build the error from the JSON dict of a finished operation.

        Args:
          api_name: name of the API that ran the operation (used in messages).
          response: operation JSON as a dict; it is not modified.
        """
        self.api_name = api_name

        # Work on a shallow copy so removing 'metadata' below does not mutate
        # the caller's dict.
        response = dict(response)

        # Operation.metadata field is Any specific to the API. It may not be
        # present in the default descriptor pool, and that's expected.
        # To avoid json_format.ParseError, handle it separately.
        self.metadata = response.pop('metadata', {})

        # Must be after removing metadata field.
        operation: Operation = self._parse_operation_response(response)
        self.name = operation.name or 'unknown'
        self.code_name = self._code_name(operation.error.code)
        self.error = operation.error
        super().__init__()

    @staticmethod
    def _code_name(code: int) -> str:
        """Return the google.rpc.Code name of code, or a fallback label."""
        try:
            return code_pb2.Code.Name(code)
        except ValueError:
            # Unrecognised code value: building the OperationError matters
            # more than a pretty code name.
            return f'UNKNOWN_CODE_{code}'

    @staticmethod
    def _parse_operation_response(operation_response: dict) -> Operation:
        """Parse operation JSON into an Operation; empty Operation on failure."""
        try:
            return json_format.ParseDict(
                operation_response,
                Operation(),
                ignore_unknown_fields=True,
                descriptor_pool=error_details_pb2.DESCRIPTOR.pool)
        except (json_format.Error, TypeError) as e:
            # Swallow parsing errors if any. Building correct OperationError()
            # is more important than losing debug information. Details still
            # can be extracted from the warning.
            logger.warning(
                ("Can't parse response while processing OperationError: '%r', "
                 "error %r"), operation_response, e)
            return Operation()

    def __str__(self):
        indent_l1 = ' ' * 2
        indent_l2 = indent_l1 * 2

        result = (f'{self.api_name} operation "{self.name}" failed.\n'
                  f'{indent_l1}code: {self.error.code} ({self.code_name})\n'
                  f'{indent_l1}message: "{self.error.message}"')

        if self.error.details:
            result += f'\n{indent_l1}details: [\n'
            for any_error in self.error.details:
                error_str = json_format.MessageToJson(any_error)
                for line in error_str.splitlines():
                    result += indent_l2 + line + '\n'
            result += f'{indent_l1}]'

        if self.metadata:
            result += '\n  metadata: \n'
            metadata_str = json.dumps(self.metadata, indent=2)
            for line in metadata_str.splitlines():
                result += indent_l2 + line + '\n'
            # Drop the newline left after the last metadata line.
            result = result.rstrip()

        return result
327
328
class GcpProjectApiResource:
    """Base wrapper around a discovery API client bound to one GCP project."""
    # TODO(sergiitk): move someplace better
    _WAIT_FOR_OPERATION_SEC = 60 * 10
    _WAIT_FIXED_SEC = 2
    _GCP_API_RETRIES = 5

    def __init__(self, api: discovery.Resource, project: str):
        self.api: discovery.Resource = api
        self.project: str = project
        self._highlighter = _HighlighterYaml()

    # TODO(sergiitk): in upcoming GCP refactoring, differentiate between
    #   _execute for LRO (Long Running Operations), and immediate operations.
    def _execute(
            self,
            request: HttpRequest,
            *,
            num_retries: Optional[int] = _GCP_API_RETRIES) -> Dict[str, Any]:
        """Execute the immediate request.

        Args:
          request: the request to execute.
          num_retries: retry count passed to request.execute(); None means
            use _GCP_API_RETRIES.

        Returns:
          Unmarshalled response as a dictionary.

        Raises:
          ResponseError if the response was not a 2xx.
          TransportError if a transport error has occurred.
        """
        if num_retries is None:
            num_retries = self._GCP_API_RETRIES
        try:
            return request.execute(num_retries=num_retries)
        except _HttpError as error:
            # Explicit chaining marks the library error as the direct cause.
            raise ResponseError(error) from error
        except _HttpLib2Error as error:
            raise TransportError(error) from error

    def resource_pretty_format(self, body: dict) -> str:
        """Return a string with pretty-printed resource body."""
        yaml_out: str = yaml.dump(body, explicit_start=True, explicit_end=True)
        return self._highlighter.highlight(yaml_out)

    @staticmethod
    def wait_for_operation(operation_request,
                           test_success_fn,
                           timeout_sec=_WAIT_FOR_OPERATION_SEC,
                           wait_sec=_WAIT_FIXED_SEC):
        """Poll operation_request until test_success_fn accepts its result.

        Calls operation_request.execute() every wait_sec seconds. An attempt
        is retried when test_success_fn(result) is falsy or when execute()
        raises any Exception. Polling stops after timeout_sec seconds; because
        reraise=True, an exception from the last attempt is re-raised as is.

        Returns:
          The first execute() result accepted by test_success_fn.
        """
        retryer = tenacity.Retrying(
            retry=(tenacity.retry_if_not_result(test_success_fn) |
                   tenacity.retry_if_exception_type()),
            wait=tenacity.wait_fixed(wait_sec),
            stop=tenacity.stop_after_delay(timeout_sec),
            after=tenacity.after_log(logger, logging.DEBUG),
            reraise=True)
        return retryer(operation_request.execute)
383
384
class GcpStandardCloudApiResource(GcpProjectApiResource, metaclass=abc.ABCMeta):
    """Base for APIs using projects/*/locations/* resource names and
    long-running operations polled via projects.locations.operations.get."""
    GLOBAL_LOCATION = 'global'

    def parent(self, location: Optional[str] = GLOBAL_LOCATION):
        """Return 'projects/<project>/locations/<location>' (global if None)."""
        if location is None:
            location = self.GLOBAL_LOCATION
        return f'projects/{self.project}/locations/{location}'

    def resource_full_name(self, name, collection_name):
        """Return the full name of a resource in the global location."""
        return f'{self.parent()}/{collection_name}/{name}'

    def _create_resource(self, collection: discovery.Resource, body: dict,
                         **kwargs):
        """Create a resource under the global parent and wait for the LRO."""
        logger.info("Creating %s resource:\n%s", self.api_name,
                    self.resource_pretty_format(body))
        create_req = collection.create(parent=self.parent(),
                                       body=body,
                                       **kwargs)
        self._execute(create_req)

    @property
    @abc.abstractmethod
    def api_name(self) -> str:
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def api_version(self) -> str:
        raise NotImplementedError

    def _get_resource(self, collection: discovery.Resource, full_name):
        """Fetch a resource by full name, log it, and return it."""
        # Retry like every other request issued through this class.
        resource = collection.get(name=full_name).execute(
            num_retries=self._GCP_API_RETRIES)
        logger.info('Loaded %s:\n%s', full_name,
                    self.resource_pretty_format(resource))
        return resource

    def _delete_resource(self, collection: discovery.Resource,
                         full_name: str) -> bool:
        """Delete a resource and wait for the LRO.

        Returns:
          True on success; False when the request fails with an HttpError
          (a 404 is logged as "does not exist", anything else as a warning).
        """
        logger.debug("Deleting %s", full_name)
        try:
            self._execute(collection.delete(name=full_name))
            return True
        except _HttpError as error:
            if error.resp and error.resp.status == 404:
                logger.info('%s not deleted since it does not exist', full_name)
            else:
                logger.warning('Failed to delete %s, %r', full_name, error)
        return False

    # TODO(sergiitk): Use ResponseError and TransportError
    def _execute(  # pylint: disable=arguments-differ
            self,
            request: HttpRequest,
            timeout_sec: int = GcpProjectApiResource._WAIT_FOR_OPERATION_SEC):
        """Execute a request that returns an LRO and wait for it to finish."""
        operation = request.execute(num_retries=self._GCP_API_RETRIES)
        logger.debug('Operation %s', operation)
        self._wait(operation['name'], timeout_sec)

    def _wait(self,
              operation_id: str,
              timeout_sec: int = GcpProjectApiResource._WAIT_FOR_OPERATION_SEC):
        """Poll operation_id until done.

        Raises:
          OperationError: the finished operation contains an 'error'.
        """
        logger.info('Waiting %s sec for %s operation id: %s', timeout_sec,
                    self.api_name, operation_id)

        op_request = self.api.projects().locations().operations().get(
            name=operation_id)
        operation = self.wait_for_operation(
            operation_request=op_request,
            # A missing 'done' key means "not done yet": the JSON form of
            # google.longrunning.Operation may omit default (false) fields,
            # and a KeyError here would escape the retry loop.
            test_success_fn=lambda result: result.get('done', False),
            timeout_sec=timeout_sec)

        logger.debug('Completed operation: %s', operation)
        if 'error' in operation:
            raise OperationError(self.api_name, operation)
459