xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio/grpc/_auth.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""GRPCAuthMetadataPlugins for standard authentication."""
15
16import inspect
17from typing import Any, Optional
18
19import grpc
20
21
def _sign_request(
    callback: grpc.AuthMetadataPluginCallback,
    token: Optional[str],
    error: Optional[Exception],
):
    """Hand a single bearer-token authorization entry to a gRPC auth callback.

    Args:
      callback: The callback gRPC supplied to the metadata plugin; it is
        invoked exactly once as ``callback(metadata, error)``.
      token: Access token placed after the ``"Bearer "`` prefix. It is
        formatted as-is, so ``None`` yields the value ``"Bearer None"``;
        within this module ``None`` is only passed alongside a non-None
        ``error``.
      error: The exception hit while obtaining the token, or ``None`` on
        success. It is forwarded to ``callback`` unchanged.
    """
    header_value = f"Bearer {token}"
    callback((("authorization", header_value),), error)
29
30
class GoogleCallCredentials(grpc.AuthMetadataPlugin):
    """Metadata wrapper for GoogleCredentials from the oauth2client library."""

    # True when the wrapped credentials' get_access_token() accepts an
    # ``additional_claims`` parameter (treated as JWT-style credentials).
    _is_jwt: bool
    _credentials: Any

    # TODO(xuanwn): Give credentials an actual type.
    def __init__(self, credentials: Any):
        """Wrap ``credentials`` as a gRPC auth metadata plugin.

        Args:
          credentials: Object exposing ``get_access_token()``, whose return
            value must have an ``access_token`` attribute.
        """
        self._credentials = credentials
        # Hack to determine if these are JWT creds and we need to pass
        # additional_claims when getting a token. __call__ passes the
        # argument by keyword, so it works whether the parameter is declared
        # positional-or-keyword (in ``args``) or keyword-only (in
        # ``kwonlyargs``); checking only ``args`` would miss the latter.
        argspec = inspect.getfullargspec(credentials.get_access_token)
        self._is_jwt = (
            "additional_claims" in argspec.args
            or "additional_claims" in argspec.kwonlyargs
        )

    def __call__(
        self,
        context: grpc.AuthMetadataContext,
        callback: grpc.AuthMetadataPluginCallback,
    ):
        """Fetch an access token and pass it to ``callback`` as bearer auth.

        For JWT-style credentials, ``context.service_url`` is supplied as the
        ``"aud"`` additional claim. Any exception raised while obtaining the
        token is reported through ``callback`` as the error rather than
        propagated to the caller.

        Args:
          context: The gRPC auth metadata context for the outgoing call.
          callback: Invoked once with the authorization metadata and an
            error (``None`` on success).
        """
        try:
            if self._is_jwt:
                access_token = self._credentials.get_access_token(
                    additional_claims={
                        "aud": context.service_url  # pytype: disable=attribute-error
                    }
                ).access_token
            else:
                access_token = self._credentials.get_access_token().access_token
        except Exception as exception:  # pylint: disable=broad-except
            # Surface the failure via the callback instead of raising.
            _sign_request(callback, None, exception)
        else:
            _sign_request(callback, access_token, None)
65
66
class AccessTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
    """Auth metadata plugin that attaches a fixed, caller-supplied access token."""

    # Token sent verbatim after "Bearer " on every call.
    _access_token: str

    def __init__(self, access_token: str):
        """Remember ``access_token`` for use on every subsequent call."""
        self._access_token = access_token

    def __call__(
        self,
        context: grpc.AuthMetadataContext,
        callback: grpc.AuthMetadataPluginCallback,
    ):
        """Report the stored token to ``callback`` as bearer authorization.

        ``context`` is not consulted: the same token is used for every call,
        and no error is ever reported.
        """
        stored_token = self._access_token
        _sign_request(callback, token=stored_token, error=None)
81