# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JWT Signature key managers."""

from typing import Any, Optional, Type, Tuple, Callable

from tink.proto import jwt_ecdsa_pb2
from tink.proto import jwt_rsa_ssa_pkcs1_pb2
from tink.proto import jwt_rsa_ssa_pss_pb2
from tink.proto import tink_pb2
from tink import core
from tink.cc.pybind import tink_bindings
from tink.jwt import _json_util
from tink.jwt import _jwt_error
from tink.jwt import _jwt_format
from tink.jwt import _jwt_public_key_sign
from tink.jwt import _jwt_public_key_verify
from tink.jwt import _jwt_validator
from tink.jwt import _raw_jwt
from tink.jwt import _verified_jwt

_JWT_ECDSA_PRIVATE_KEY_TYPE = 'type.googleapis.com/google.crypto.tink.JwtEcdsaPrivateKey'
_JWT_ECDSA_PUBLIC_KEY_TYPE = 'type.googleapis.com/google.crypto.tink.JwtEcdsaPublicKey'

_JWT_RSA_SSA_PKCS1_PRIVATE_KEY_TYPE = 'type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PrivateKey'
_JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE = 'type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PublicKey'

_JWT_RSA_SSA_PSS_PRIVATE_KEY_TYPE = 'type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPrivateKey'
_JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE = 'type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPublicKey'

# Maps each proto algorithm enum value to its JWT algorithm name.
_ECDSA_ALGORITHM_TEXTS = {
    jwt_ecdsa_pb2.ES256: 'ES256',
    jwt_ecdsa_pb2.ES384: 'ES384',
    jwt_ecdsa_pb2.ES512: 'ES512'
}

_RSA_SSA_PKCS1_ALGORITHM_TEXTS = {
    jwt_rsa_ssa_pkcs1_pb2.RS256: 'RS256',
    jwt_rsa_ssa_pkcs1_pb2.RS384: 'RS384',
    jwt_rsa_ssa_pkcs1_pb2.RS512: 'RS512'
}

_RSA_SSA_PSS_ALGORITHM_TEXTS = {
    jwt_rsa_ssa_pss_pb2.PS256: 'PS256',
    jwt_rsa_ssa_pss_pb2.PS384: 'PS384',
    jwt_rsa_ssa_pss_pb2.PS512: 'PS512'
}


class _JwtPublicKeySign(_jwt_public_key_sign.JwtPublicKeySignInternal):
  """Implementation of JwtPublicKeySignInternal using a PublicKeySign."""

  def __init__(self, cc_primitive: tink_bindings.PublicKeySign, algorithm: str,
               custom_kid: Optional[str]):
    """Creates the JWT signing primitive.

    Args:
      cc_primitive: The C++ PublicKeySign primitive used to compute signatures.
      algorithm: The JWT algorithm name, e.g. 'ES256' or 'RS256'.
      custom_kid: The key's custom "kid" value, or None if the key has none.
    """
    self._public_key_sign = cc_primitive
    self._algorithm = algorithm
    self._custom_kid = custom_kid

  @core.use_tink_errors
  def _sign(self, data: bytes) -> bytes:
    return self._public_key_sign.sign(data)

  def sign_and_encode_with_kid(self, raw_jwt: _raw_jwt.RawJwt,
                               kid: Optional[str]) -> str:
    """Computes a signature and encodes the token.

    Args:
      raw_jwt: The RawJwt token to be signed and encoded.
      kid: Optional "kid" header value. It is set by the wrapper for keys with
        output prefix TINK, and it is None for output prefix RAW.

    Returns:
      The signed token encoded in the JWS compact serialization format.
    Raises:
      _jwt_error.JwtInvalidError: if the key has a custom kid and a kid is
        also passed in (i.e. the key has output prefix TINK).
      tink.TinkError if the operation fails.
    """
    if self._custom_kid is not None:
      # A custom kid and a Tink-assigned kid are mutually exclusive.
      if kid is not None:
        raise _jwt_error.JwtInvalidError(
            'custom_kid must not be set for keys with output prefix type TINK')
      kid = self._custom_kid
    unsigned = _jwt_format.create_unsigned_compact(self._algorithm, kid,
                                                   raw_jwt)
    return _jwt_format.create_signed_compact(unsigned, self._sign(unsigned))


class _JwtPublicKeyVerify(_jwt_public_key_verify.JwtPublicKeyVerifyInternal):
  """Implementation of JwtPublicKeyVerifyInternal using a PublicKeyVerify."""

  def __init__(self, cc_primitive: tink_bindings.PublicKeyVerify,
               algorithm: str, custom_kid: Optional[str]):
    """Creates the JWT verification primitive.

    Args:
      cc_primitive: The C++ PublicKeyVerify primitive used to check signatures.
      algorithm: The JWT algorithm name the token header must carry.
      custom_kid: The key's custom "kid" value, or None if the key has none.
    """
    self._public_key_verify = cc_primitive
    self._algorithm = algorithm
    self._custom_kid = custom_kid

  @core.use_tink_errors
  def _verify(self, signature: bytes, data: bytes) -> None:
    self._public_key_verify.verify(signature, data)

  def verify_and_decode_with_kid(
      self, compact: str, validator: _jwt_validator.JwtValidator,
      kid: Optional[str]) -> _verified_jwt.VerifiedJwt:
    """Verifies, validates and decodes a signed compact JWT token.

    Args:
      compact: The token in JWS compact serialization format.
      validator: The validator applied to the decoded claims.
      kid: The kid the wrapper expects for keys with output prefix TINK, or
        None for output prefix RAW.

    Returns:
      The verified JWT.
    """
    parts = _jwt_format.split_signed_compact(compact)
    unsigned_compact, json_header, json_payload, signature = parts
    # The signature is checked before the header or payload is parsed.
    self._verify(signature, unsigned_compact)
    header = _json_util.json_loads(json_header)
    _jwt_format.validate_header(
        header=header,
        algorithm=self._algorithm,
        tink_kid=kid,
        custom_kid=self._custom_kid)
    raw_jwt = _raw_jwt.raw_jwt_from_json(
        _jwt_format.get_type_header(header), json_payload)
    _jwt_validator.validate(validator, raw_jwt)
    return _verified_jwt.VerifiedJwt._create(raw_jwt)  # pylint: disable=protected-access


class _JwtPublicKeySignKeyManagerCcToPyWrapper(
    core.PrivateKeyManager[_jwt_public_key_sign.JwtPublicKeySignInternal]):
  """Converts a C++ sign key manager into a JwtPublicKeySignKeyManager."""

  def __init__(self, cc_key_manager: tink_bindings.PublicKeySignKeyManager,
               key_data_to_alg_kid: Callable[[tink_pb2.KeyData],
                                             Tuple[str, Optional[str]]]):
    """Creates the wrapper.

    Args:
      cc_key_manager: The C++ key manager that creates the raw primitive.
      key_data_to_alg_kid: Extracts (JWT algorithm name, custom kid or None)
        from a private key's KeyData.
    """
    self._cc_key_manager = cc_key_manager
    self._key_data_to_alg_kid = key_data_to_alg_kid

  def primitive_class(
      self) -> Type[_jwt_public_key_sign.JwtPublicKeySignInternal]:
    return _jwt_public_key_sign.JwtPublicKeySignInternal

  @core.use_tink_errors
  def primitive(
      self, key_data: tink_pb2.KeyData
  ) -> _jwt_public_key_sign.JwtPublicKeySignInternal:
    sign = self._cc_key_manager.primitive(key_data.SerializeToString())
    algorithm, custom_kid = self._key_data_to_alg_kid(key_data)
    return _JwtPublicKeySign(sign, algorithm, custom_kid)

  def key_type(self) -> str:
    return self._cc_key_manager.key_type()

  @core.use_tink_errors
  def new_key_data(self,
                   key_template: tink_pb2.KeyTemplate) -> tink_pb2.KeyData:
    return tink_pb2.KeyData.FromString(
        self._cc_key_manager.new_key_data(key_template.SerializeToString()))

  @core.use_tink_errors
  def public_key_data(self, key_data: tink_pb2.KeyData) -> tink_pb2.KeyData:
    return tink_pb2.KeyData.FromString(
        self._cc_key_manager.public_key_data(key_data.SerializeToString()))


class _JwtPublicKeyVerifyKeyManagerCcToPyWrapper(
    core.KeyManager[_jwt_public_key_verify.JwtPublicKeyVerifyInternal]):
  """Converts a C++ verify key manager into a JwtPublicKeyVerifyKeyManager."""

  def __init__(self, cc_key_manager: tink_bindings.PublicKeyVerifyKeyManager,
               key_data_to_alg_kid: Callable[[tink_pb2.KeyData],
                                             Tuple[str, Optional[str]]]):
    """Creates the wrapper.

    Args:
      cc_key_manager: The C++ key manager that creates the raw primitive.
      key_data_to_alg_kid: Extracts (JWT algorithm name, custom kid or None)
        from a public key's KeyData.
    """
    self._cc_key_manager = cc_key_manager
    self._key_data_to_alg_kid = key_data_to_alg_kid

  def primitive_class(
      self) -> Type[_jwt_public_key_verify.JwtPublicKeyVerifyInternal]:
    return _jwt_public_key_verify.JwtPublicKeyVerifyInternal

  @core.use_tink_errors
  def primitive(
      self, key_data: tink_pb2.KeyData
  ) -> _jwt_public_key_verify.JwtPublicKeyVerifyInternal:
    verify = self._cc_key_manager.primitive(key_data.SerializeToString())
    algorithm, custom_kid = self._key_data_to_alg_kid(key_data)
    return _JwtPublicKeyVerify(verify, algorithm, custom_kid)

  def key_type(self) -> str:
    return self._cc_key_manager.key_type()

  @core.use_tink_errors
  def new_key_data(self,
                   key_template: tink_pb2.KeyTemplate) -> tink_pb2.KeyData:
    return tink_pb2.KeyData.FromString(
        self._cc_key_manager.new_key_data(key_template.SerializeToString()))


def _ecdsa_algorithm_text(algorithm: jwt_ecdsa_pb2.JwtEcdsaAlgorithm) -> str:
  """Returns the JWT name ('ES256', ...) of an ECDSA algorithm enum value.

  Raises:
    _jwt_error.JwtInvalidError: if the algorithm is not ES256/ES384/ES512.
  """
  if algorithm not in _ECDSA_ALGORITHM_TEXTS:
    raise _jwt_error.JwtInvalidError('Invalid algorithm')
  return _ECDSA_ALGORITHM_TEXTS[algorithm]


def _get_custom_kid(public_key_proto: Any) -> Optional[str]:
  """Returns the public key proto's custom_kid value, or None if it is unset."""
  if public_key_proto.HasField('custom_kid'):
    return public_key_proto.custom_kid.value
  else:
    return None


def _ecdsa_alg_kid_from_private_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Returns (algorithm name, custom kid) for JwtEcdsaPrivateKey key data.

  Raises:
    _jwt_error.JwtInvalidError: if key_data has the wrong type URL or an
      unsupported algorithm.
  """
  if key_data.type_url != _JWT_ECDSA_PRIVATE_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  key = jwt_ecdsa_pb2.JwtEcdsaPrivateKey.FromString(key_data.value)
  return (_ecdsa_algorithm_text(key.public_key.algorithm),
          _get_custom_kid(key.public_key))


def _ecdsa_alg_kid_from_public_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Returns (algorithm name, custom kid) for JwtEcdsaPublicKey key data.

  Raises:
    _jwt_error.JwtInvalidError: if key_data has the wrong type URL or an
      unsupported algorithm.
  """
  if key_data.type_url != _JWT_ECDSA_PUBLIC_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  key = jwt_ecdsa_pb2.JwtEcdsaPublicKey.FromString(key_data.value)
  return (_ecdsa_algorithm_text(key.algorithm), _get_custom_kid(key))


def _rsa_ssa_pkcs1_algorithm_text(
    algorithm: jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1Algorithm) -> str:
  """Returns the JWT name ('RS256', ...) of an RSA SSA PKCS1 algorithm value.

  Raises:
    _jwt_error.JwtInvalidError: if the algorithm is not RS256/RS384/RS512.
  """
  if algorithm not in _RSA_SSA_PKCS1_ALGORITHM_TEXTS:
    raise _jwt_error.JwtInvalidError('Invalid algorithm')
  return _RSA_SSA_PKCS1_ALGORITHM_TEXTS[algorithm]
def _rsa_ssa_pkcs1_alg_kid_from_private_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Extracts (JWT algorithm name, custom kid) from PKCS1 private key data.

  Raises:
    _jwt_error.JwtInvalidError: if the type URL is not the
      JwtRsaSsaPkcs1PrivateKey one, or the algorithm is unsupported.
  """
  if key_data.type_url != _JWT_RSA_SSA_PKCS1_PRIVATE_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  private_key = jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PrivateKey.FromString(
      key_data.value)
  embedded_public = private_key.public_key
  alg_name = _rsa_ssa_pkcs1_algorithm_text(embedded_public.algorithm)
  kid_value = _get_custom_kid(embedded_public)
  return alg_name, kid_value


def _rsa_ssa_pkcs1_alg_kid_from_public_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Extracts (JWT algorithm name, custom kid) from PKCS1 public key data.

  Raises:
    _jwt_error.JwtInvalidError: if the type URL is not the
      JwtRsaSsaPkcs1PublicKey one, or the algorithm is unsupported.
  """
  if key_data.type_url != _JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  public_key = jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PublicKey.FromString(
      key_data.value)
  alg_name = _rsa_ssa_pkcs1_algorithm_text(public_key.algorithm)
  kid_value = _get_custom_kid(public_key)
  return alg_name, kid_value


def _rsa_ssa_pss_algorithm_text(
    algorithm: jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssAlgorithm) -> str:
  """Translates an RSA SSA PSS algorithm enum value into 'PS256'/'PS384'/...

  Raises:
    _jwt_error.JwtInvalidError: for any value outside the known table.
  """
  name = _RSA_SSA_PSS_ALGORITHM_TEXTS.get(algorithm)
  if name is None:
    raise _jwt_error.JwtInvalidError('Invalid algorithm')
  return name


def _rsa_ssa_pss_alg_kid_from_private_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Extracts (JWT algorithm name, custom kid) from PSS private key data.

  Raises:
    _jwt_error.JwtInvalidError: if the type URL is not the
      JwtRsaSsaPssPrivateKey one, or the algorithm is unsupported.
  """
  if key_data.type_url != _JWT_RSA_SSA_PSS_PRIVATE_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  private_key = jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPrivateKey.FromString(
      key_data.value)
  embedded_public = private_key.public_key
  alg_name = _rsa_ssa_pss_algorithm_text(embedded_public.algorithm)
  kid_value = _get_custom_kid(embedded_public)
  return alg_name, kid_value


def _rsa_ssa_pss_alg_kid_from_public_key_data(
    key_data: tink_pb2.KeyData) -> Tuple[str, Optional[str]]:
  """Extracts (JWT algorithm name, custom kid) from PSS public key data.

  Raises:
    _jwt_error.JwtInvalidError: if the type URL is not the
      JwtRsaSsaPssPublicKey one, or the algorithm is unsupported.
  """
  if key_data.type_url != _JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE:
    raise _jwt_error.JwtInvalidError('Invalid key data key type')
  public_key = jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPublicKey.FromString(
      key_data.value)
  alg_name = _rsa_ssa_pss_algorithm_text(public_key.algorithm)
  kid_value = _get_custom_kid(public_key)
  return alg_name, kid_value


def register():
  """Registers all JWT signature primitives.

  First registers the C++ JWT key managers via tink_bindings.register_jwt().
  Then, for ECDSA, RSA SSA PKCS1 and RSA SSA PSS (in that order), it wraps the
  C++ sign and verify key managers in Python and registers each of them with
  new_key_allowed=True.
  """
  tink_bindings.register_jwt()

  # Each row: (private type URL, private alg/kid extractor,
  #            public type URL, public alg/kid extractor).
  families = (
      (_JWT_ECDSA_PRIVATE_KEY_TYPE, _ecdsa_alg_kid_from_private_key_data,
       _JWT_ECDSA_PUBLIC_KEY_TYPE, _ecdsa_alg_kid_from_public_key_data),
      (_JWT_RSA_SSA_PKCS1_PRIVATE_KEY_TYPE,
       _rsa_ssa_pkcs1_alg_kid_from_private_key_data,
       _JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE,
       _rsa_ssa_pkcs1_alg_kid_from_public_key_data),
      (_JWT_RSA_SSA_PSS_PRIVATE_KEY_TYPE,
       _rsa_ssa_pss_alg_kid_from_private_key_data,
       _JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE,
       _rsa_ssa_pss_alg_kid_from_public_key_data),
  )
  for private_url, private_extractor, public_url, public_extractor in families:
    sign_manager = _JwtPublicKeySignKeyManagerCcToPyWrapper(
        tink_bindings.PublicKeySignKeyManager.from_cc_registry(private_url),
        private_extractor)
    core.Registry.register_key_manager(sign_manager, new_key_allowed=True)

    verify_manager = _JwtPublicKeyVerifyKeyManagerCcToPyWrapper(
        tink_bindings.PublicKeyVerifyKeyManager.from_cc_registry(public_url),
        public_extractor)
    core.Registry.register_key_manager(verify_manager, new_key_allowed=True)