xref: /aosp_15_r20/external/tink/python/tink/jwt/_jwt_signature_key_manager.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""JWT Signature key managers."""
15
16from typing import Any, Optional, Type, Tuple, Callable
17
18from tink.proto import jwt_ecdsa_pb2
19from tink.proto import jwt_rsa_ssa_pkcs1_pb2
20from tink.proto import jwt_rsa_ssa_pss_pb2
21from tink.proto import tink_pb2
22from tink import core
23from tink.cc.pybind import tink_bindings
24from tink.jwt import _json_util
25from tink.jwt import _jwt_error
26from tink.jwt import _jwt_format
27from tink.jwt import _jwt_public_key_sign
28from tink.jwt import _jwt_public_key_verify
29from tink.jwt import _jwt_validator
30from tink.jwt import _raw_jwt
31from tink.jwt import _verified_jwt
32
# Type URLs of the JWT signature key protos handled by this module.
_JWT_ECDSA_PRIVATE_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtEcdsaPrivateKey')
_JWT_ECDSA_PUBLIC_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtEcdsaPublicKey')

_JWT_RSA_SSA_PKCS1_PRIVATE_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PrivateKey')
_JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PublicKey')

_JWT_RSA_SSA_PSS_PRIVATE_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPrivateKey')
_JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPublicKey')

# Algorithm names keyed by proto enum value. The name is what gets passed to
# _jwt_format when the token header is created and validated.
_ECDSA_ALGORITHM_TEXTS = {
    jwt_ecdsa_pb2.ES256: 'ES256',
    jwt_ecdsa_pb2.ES384: 'ES384',
    jwt_ecdsa_pb2.ES512: 'ES512',
}

_RSA_SSA_PKCS1_ALGORITHM_TEXTS = {
    jwt_rsa_ssa_pkcs1_pb2.RS256: 'RS256',
    jwt_rsa_ssa_pkcs1_pb2.RS384: 'RS384',
    jwt_rsa_ssa_pkcs1_pb2.RS512: 'RS512',
}

_RSA_SSA_PSS_ALGORITHM_TEXTS = {
    jwt_rsa_ssa_pss_pb2.PS256: 'PS256',
    jwt_rsa_ssa_pss_pb2.PS384: 'PS384',
    jwt_rsa_ssa_pss_pb2.PS512: 'PS512',
}
59
60
class _JwtPublicKeySign(_jwt_public_key_sign.JwtPublicKeySignInternal):
  """Implementation of JwtPublicKeySignInternal using a PublicKeySign.

  Wraps a C++ PublicKeySign primitive and emits signed tokens in the JWS
  compact serialization, using a fixed algorithm name and, optionally, a
  custom "kid" header value that comes from the key itself.
  """

  def __init__(self, cc_primitive: tink_bindings.PublicKeySign, algorithm: str,
               custom_kid: Optional[str]):
    """Creates the primitive.

    Args:
      cc_primitive: The C++ PublicKeySign primitive that computes signatures.
      algorithm: Algorithm name written into the token header, e.g. 'ES256'.
      custom_kid: The key's custom "kid" value, or None if the key has none.
    """
    self._public_key_sign = cc_primitive
    self._algorithm = algorithm
    self._custom_kid = custom_kid

  @core.use_tink_errors
  def _sign(self, data: bytes) -> bytes:
    return self._public_key_sign.sign(data)

  def sign_and_encode_with_kid(self, raw_jwt: _raw_jwt.RawJwt,
                               kid: Optional[str]) -> str:
    """Computes a signature and encodes the token.

    Args:
      raw_jwt: The RawJwt token to be signed and encoded.
      kid: Optional "kid" header value. It is set by the wrapper for keys with
        output prefix TINK, and it is None for output prefix RAW.

    Returns:
      The signed token encoded in the JWS compact serialization format.
    Raises:
      tink.TinkError if the operation fails. In particular, a JwtInvalidError
      is raised if the key has a custom kid and `kid` is also given, because
      the header can carry only one of them.
    """
    if self._custom_kid is not None:
      if kid is not None:
        raise _jwt_error.JwtInvalidError(
            'custom_kid must not be set for keys with output prefix type TINK')
      kid = self._custom_kid
    unsigned = _jwt_format.create_unsigned_compact(self._algorithm, kid,
                                                   raw_jwt)
    return _jwt_format.create_signed_compact(unsigned, self._sign(unsigned))
96
97
class _JwtPublicKeyVerify(_jwt_public_key_verify.JwtPublicKeyVerifyInternal):
  """JwtPublicKeyVerifyInternal backed by a C++ PublicKeyVerify primitive.

  Checks the signature of a compact JWS token, then validates its header and
  claims before handing back a VerifiedJwt.
  """

  def __init__(self, cc_primitive: tink_bindings.PublicKeyVerify,
               algorithm: str, custom_kid: Optional[str]):
    """Creates the primitive.

    Args:
      cc_primitive: The C++ PublicKeyVerify primitive that checks signatures.
      algorithm: Algorithm name the token header must declare, e.g. 'ES256'.
      custom_kid: The key's custom "kid" value, or None if the key has none.
    """
    self._algorithm = algorithm
    self._custom_kid = custom_kid
    self._public_key_verify = cc_primitive

  @core.use_tink_errors
  def _verify(self, signature: bytes, data: bytes) -> None:
    self._public_key_verify.verify(signature, data)

  def verify_and_decode_with_kid(
      self, compact: str, validator: _jwt_validator.JwtValidator,
      kid: Optional[str]) -> _verified_jwt.VerifiedJwt:
    """Verifies, validates and decodes a signed compact JWT token.

    The signature is checked before the JSON header and payload are parsed.

    Args:
      compact: The token in the JWS compact serialization format.
      validator: Validator applied to the decoded claims.
      kid: "kid" value supplied by the wrapper, forwarded to header validation
        as `tink_kid` (presumably None for keys with output prefix RAW).

    Returns:
      A VerifiedJwt built from the validated token.
    """
    (unsigned_compact, json_header, json_payload,
     signature) = _jwt_format.split_signed_compact(compact)
    # Authenticate first; nothing below runs on an unverified token.
    self._verify(signature, unsigned_compact)
    decoded_header = _json_util.json_loads(json_header)
    _jwt_format.validate_header(
        header=decoded_header,
        algorithm=self._algorithm,
        tink_kid=kid,
        custom_kid=self._custom_kid)
    type_header = _jwt_format.get_type_header(decoded_header)
    token = _raw_jwt.raw_jwt_from_json(type_header, json_payload)
    _jwt_validator.validate(validator, token)
    return _verified_jwt.VerifiedJwt._create(token)  # pylint: disable=protected-access
128
129
class _JwtPublicKeySignKeyManagerCcToPyWrapper(
    core.PrivateKeyManager[_jwt_public_key_sign.JwtPublicKeySignInternal]):
  """Adapts a C++ PublicKeySign key manager to produce JWT sign primitives.

  Key generation and public-key extraction are delegated to the C++ manager
  through serialized protos. primitive() additionally derives the JWT
  algorithm name and custom kid from the key data via `key_data_to_alg_kid`.
  """

  def __init__(self, cc_key_manager: tink_bindings.PublicKeySignKeyManager,
               key_data_to_alg_kid: Callable[[tink_pb2.KeyData],
                                             Tuple[str, Optional[str]]]):
    self._cc_key_manager = cc_key_manager
    self._key_data_to_alg_kid = key_data_to_alg_kid

  def primitive_class(
      self) -> Type[_jwt_public_key_sign.JwtPublicKeySignInternal]:
    return _jwt_public_key_sign.JwtPublicKeySignInternal

  @core.use_tink_errors
  def primitive(
      self, key_data: tink_pb2.KeyData
  ) -> _jwt_public_key_sign.JwtPublicKeySignInternal:
    serialized_key_data = key_data.SerializeToString()
    cc_sign = self._cc_key_manager.primitive(serialized_key_data)
    alg_name, key_custom_kid = self._key_data_to_alg_kid(key_data)
    return _JwtPublicKeySign(cc_sign, alg_name, key_custom_kid)

  def key_type(self) -> str:
    return self._cc_key_manager.key_type()

  @core.use_tink_errors
  def new_key_data(self,
                   key_template: tink_pb2.KeyTemplate) -> tink_pb2.KeyData:
    serialized_template = key_template.SerializeToString()
    serialized_key_data = self._cc_key_manager.new_key_data(
        serialized_template)
    return tink_pb2.KeyData.FromString(serialized_key_data)

  @core.use_tink_errors
  def public_key_data(self, key_data: tink_pb2.KeyData) -> tink_pb2.KeyData:
    serialized_private = key_data.SerializeToString()
    serialized_public = self._cc_key_manager.public_key_data(
        serialized_private)
    return tink_pb2.KeyData.FromString(serialized_public)
165
166
class _JwtPublicKeyVerifyKeyManagerCcToPyWrapper(
    core.KeyManager[_jwt_public_key_verify.JwtPublicKeyVerifyInternal]):
  """Adapts a C++ PublicKeyVerify key manager to produce JWT verify primitives.

  Key generation is delegated to the C++ manager through serialized protos.
  primitive() additionally derives the JWT algorithm name and custom kid from
  the key data via `key_data_to_alg_kid`.
  """

  def __init__(self, cc_key_manager: tink_bindings.PublicKeyVerifyKeyManager,
               key_data_to_alg_kid: Callable[[tink_pb2.KeyData],
                                             Tuple[str, Optional[str]]]):
    self._cc_key_manager = cc_key_manager
    self._key_data_to_alg_kid = key_data_to_alg_kid

  def primitive_class(
      self) -> Type[_jwt_public_key_verify.JwtPublicKeyVerifyInternal]:
    return _jwt_public_key_verify.JwtPublicKeyVerifyInternal

  @core.use_tink_errors
  def primitive(
      self, key_data: tink_pb2.KeyData
  ) -> _jwt_public_key_verify.JwtPublicKeyVerifyInternal:
    serialized_key_data = key_data.SerializeToString()
    cc_verify = self._cc_key_manager.primitive(serialized_key_data)
    alg_name, key_custom_kid = self._key_data_to_alg_kid(key_data)
    return _JwtPublicKeyVerify(cc_verify, alg_name, key_custom_kid)

  def key_type(self) -> str:
    return self._cc_key_manager.key_type()

  @core.use_tink_errors
  def new_key_data(self,
                   key_template: tink_pb2.KeyTemplate) -> tink_pb2.KeyData:
    serialized_template = key_template.SerializeToString()
    serialized_key_data = self._cc_key_manager.new_key_data(
        serialized_template)
    return tink_pb2.KeyData.FromString(serialized_key_data)
197
198
def _ecdsa_algorithm_text(algorithm: jwt_ecdsa_pb2.JwtEcdsaAlgorithm) -> str:
  """Returns the algorithm name (e.g. 'ES256') for an ECDSA enum value.

  Raises:
    JwtInvalidError: if `algorithm` is not in _ECDSA_ALGORITHM_TEXTS.
  """
  name = _ECDSA_ALGORITHM_TEXTS.get(algorithm)
  if name is None:
    raise _jwt_error.JwtInvalidError('Invalid algorithm')
  return name
203
204
def _get_custom_kid(public_key_proto: Any) -> Optional[str]:
  """Returns the key's custom "kid" string, or None if the field is unset.

  Args:
    public_key_proto: A JWT public key proto whose optional `custom_kid`
      message field wraps a string `value`.
  """
  has_custom_kid = public_key_proto.HasField('custom_kid')
  return public_key_proto.custom_kid.value if has_custom_kid else None
210
211
def _ecdsa_alg_kid_from_private_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Returns (algorithm name, custom kid) for JwtEcdsaPrivateKey key data.

  Raises:
    JwtInvalidError: if the type URL is not the ECDSA private key type or the
      algorithm is unknown.
  """
  if key_data.type_url != _JWT_ECDSA_PRIVATE_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  private_key = jwt_ecdsa_pb2.JwtEcdsaPrivateKey.FromString(key_data.value)
  public_key = private_key.public_key
  alg_name = _ecdsa_algorithm_text(public_key.algorithm)
  return alg_name, _get_custom_kid(public_key)
219
220
def _ecdsa_alg_kid_from_public_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Returns (algorithm name, custom kid) for JwtEcdsaPublicKey key data.

  Raises:
    JwtInvalidError: if the type URL is not the ECDSA public key type or the
      algorithm is unknown.
  """
  if key_data.type_url != _JWT_ECDSA_PUBLIC_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  public_key = jwt_ecdsa_pb2.JwtEcdsaPublicKey.FromString(key_data.value)
  alg_name = _ecdsa_algorithm_text(public_key.algorithm)
  return alg_name, _get_custom_kid(public_key)
227
228
def _rsa_ssa_pkcs1_algorithm_text(
    algorithm: jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1Algorithm) -> str:
  """Returns the algorithm name (e.g. 'RS256') for an RSA-SSA-PKCS1 enum value.

  Raises:
    JwtInvalidError: if `algorithm` is not in _RSA_SSA_PKCS1_ALGORITHM_TEXTS.
  """
  name = _RSA_SSA_PKCS1_ALGORITHM_TEXTS.get(algorithm)
  if name is None:
    raise _jwt_error.JwtInvalidError('Invalid algorithm')
  return name
234
235
def _rsa_ssa_pkcs1_alg_kid_from_private_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Returns (algorithm name, custom kid) for JwtRsaSsaPkcs1PrivateKey data.

  Raises:
    JwtInvalidError: if the type URL is not the RSA-SSA-PKCS1 private key type
      or the algorithm is unknown.
  """
  if key_data.type_url != _JWT_RSA_SSA_PKCS1_PRIVATE_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  private_key = jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PrivateKey.FromString(
      key_data.value)
  public_key = private_key.public_key
  alg_name = _rsa_ssa_pkcs1_algorithm_text(public_key.algorithm)
  return alg_name, _get_custom_kid(public_key)
244
245
def _rsa_ssa_pkcs1_alg_kid_from_public_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Returns (algorithm name, custom kid) for JwtRsaSsaPkcs1PublicKey data.

  Raises:
    JwtInvalidError: if the type URL is not the RSA-SSA-PKCS1 public key type
      or the algorithm is unknown.
  """
  if key_data.type_url != _JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  public_key = jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PublicKey.FromString(
      key_data.value)
  alg_name = _rsa_ssa_pkcs1_algorithm_text(public_key.algorithm)
  return alg_name, _get_custom_kid(public_key)
252
253
def _rsa_ssa_pss_algorithm_text(
    algorithm: jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssAlgorithm) -> str:
  """Returns the algorithm name (e.g. 'PS256') for an RSA-SSA-PSS enum value.

  Raises:
    JwtInvalidError: if `algorithm` is not in _RSA_SSA_PSS_ALGORITHM_TEXTS.
  """
  name = _RSA_SSA_PSS_ALGORITHM_TEXTS.get(algorithm)
  if name is None:
    raise _jwt_error.JwtInvalidError('Invalid algorithm')
  return name
259
260
def _rsa_ssa_pss_alg_kid_from_private_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Returns (algorithm name, custom kid) for JwtRsaSsaPssPrivateKey data.

  Raises:
    JwtInvalidError: if the type URL is not the RSA-SSA-PSS private key type
      or the algorithm is unknown.
  """
  if key_data.type_url != _JWT_RSA_SSA_PSS_PRIVATE_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  private_key = jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPrivateKey.FromString(
      key_data.value)
  public_key = private_key.public_key
  alg_name = _rsa_ssa_pss_algorithm_text(public_key.algorithm)
  return alg_name, _get_custom_kid(public_key)
268
269
def _rsa_ssa_pss_alg_kid_from_public_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Returns (algorithm name, custom kid) for JwtRsaSsaPssPublicKey data.

  Raises:
    JwtInvalidError: if the type URL is not the RSA-SSA-PSS public key type
      or the algorithm is unknown.
  """
  if key_data.type_url != _JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  public_key = jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPublicKey.FromString(
      key_data.value)
  alg_name = _rsa_ssa_pss_algorithm_text(public_key.algorithm)
  return alg_name, _get_custom_kid(public_key)
276
277
def register():
  """Registers all JWT signature primitives.

  Calls tink_bindings.register_jwt() first; the C++ key managers are then
  looked up with from_cc_registry. For each algorithm family (ECDSA,
  RSA-SSA-PKCS1, RSA-SSA-PSS), the private-key (sign) manager and then the
  public-key (verify) manager are wrapped and added to core.Registry with
  new_key_allowed=True.
  """
  tink_bindings.register_jwt()

  # (Python wrapper class, C++ key manager class, type URL, alg/kid extractor),
  # listed in registration order.
  manager_specs = (
      (_JwtPublicKeySignKeyManagerCcToPyWrapper,
       tink_bindings.PublicKeySignKeyManager,
       _JWT_ECDSA_PRIVATE_KEY_TYPE,
       _ecdsa_alg_kid_from_private_key_data),
      (_JwtPublicKeyVerifyKeyManagerCcToPyWrapper,
       tink_bindings.PublicKeyVerifyKeyManager,
       _JWT_ECDSA_PUBLIC_KEY_TYPE,
       _ecdsa_alg_kid_from_public_key_data),
      (_JwtPublicKeySignKeyManagerCcToPyWrapper,
       tink_bindings.PublicKeySignKeyManager,
       _JWT_RSA_SSA_PKCS1_PRIVATE_KEY_TYPE,
       _rsa_ssa_pkcs1_alg_kid_from_private_key_data),
      (_JwtPublicKeyVerifyKeyManagerCcToPyWrapper,
       tink_bindings.PublicKeyVerifyKeyManager,
       _JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE,
       _rsa_ssa_pkcs1_alg_kid_from_public_key_data),
      (_JwtPublicKeySignKeyManagerCcToPyWrapper,
       tink_bindings.PublicKeySignKeyManager,
       _JWT_RSA_SSA_PSS_PRIVATE_KEY_TYPE,
       _rsa_ssa_pss_alg_kid_from_private_key_data),
      (_JwtPublicKeyVerifyKeyManagerCcToPyWrapper,
       tink_bindings.PublicKeyVerifyKeyManager,
       _JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE,
       _rsa_ssa_pss_alg_kid_from_public_key_data),
  )
  for wrapper_cls, cc_manager_cls, type_url, alg_kid_fn in manager_specs:
    cc_manager = cc_manager_cls.from_cc_registry(type_url)
    core.Registry.register_key_manager(
        wrapper_cls(cc_manager, alg_kid_fn), new_key_allowed=True)
315