# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Primitive set wrappers for JwtPublicKeySign and JwtPublicKeyVerify."""

from typing import Type

from tink.proto import tink_pb2
from tink import core
from tink.jwt import _jwt_error
from tink.jwt import _jwt_format
from tink.jwt import _jwt_public_key_sign
from tink.jwt import _jwt_public_key_verify
from tink.jwt import _jwt_validator
from tink.jwt import _raw_jwt
from tink.jwt import _verified_jwt


class _WrappedJwtPublicKeySign(_jwt_public_key_sign.JwtPublicKeySign):
  """A JwtPublicKeySign that signs with the primary entry of a primitive set."""

  def __init__(self, pset: core.PrimitiveSet) -> None:
    """Stores the primitive set whose primary entry is used for signing."""
    self._primitive_set = pset

  def sign_and_encode(self, raw_jwt: _raw_jwt.RawJwt) -> str:
    """Signs and encodes raw_jwt using the primary entry of the set.

    Args:
      raw_jwt: The token to sign.

    Returns:
      Whatever the primary primitive's sign_and_encode_with_kid returns.
    """
    primary = self._primitive_set.primary()
    # The kid is derived from the primary entry's key id and output prefix
    # type; the exact mapping is defined by _jwt_format.get_kid.
    kid = _jwt_format.get_kid(primary.key_id, primary.output_prefix_type)
    return primary.primitive.sign_and_encode_with_kid(raw_jwt, kid)


class _WrappedJwtPublicKeyVerify(_jwt_public_key_verify.JwtPublicKeyVerify):
  """A JwtPublicKeyVerify that tries every entry of a primitive set."""

  def __init__(self, pset: core.PrimitiveSet) -> None:
    """Stores the primitive set whose entries are used for verification."""
    self._primitive_set = pset

  def verify_and_decode(
      self, compact: str,
      validator: _jwt_validator.JwtValidator) -> _verified_jwt.VerifiedJwt:
    """Verifies and decodes compact with the first entry that accepts it.

    Every entry of the primitive set is tried in turn; the result of the first
    entry that does not raise core.TinkError is returned.

    Args:
      compact: The encoded token to verify.
      validator: The validator passed through to each entry's primitive.

    Returns:
      The result of the first successful verify_and_decode_with_kid call.

    Raises:
      core.TinkError: If no entry accepts the token. When at least one entry
        failed with _jwt_error.JwtInvalidError, the most recent such error is
        raised instead of the generic 'invalid signature' error.
    """
    interesting_error = None
    for entries in self._primitive_set.all():
      for entry in entries:
        try:
          kid = _jwt_format.get_kid(entry.key_id, entry.output_prefix_type)
          return entry.primitive.verify_and_decode_with_kid(
              compact, validator, kid)
        except core.TinkError as e:
          # A failure of one entry is not fatal: keep trying the remaining
          # entries. A JwtInvalidError is remembered so it can be reported in
          # preference to the generic error if every entry fails.
          if isinstance(e, _jwt_error.JwtInvalidError):
            interesting_error = e
    if interesting_error is not None:
      raise interesting_error
    raise core.TinkError('invalid signature')


def _validate_primitive_set(pset: core.PrimitiveSet) -> None:
  """Checks that every entry of pset uses a RAW or TINK output prefix type.

  Args:
    pset: The primitive set to check.

  Raises:
    core.TinkError: If any entry has another output prefix type.
  """
  for entries in pset.all():
    for entry in entries:
      if entry.output_prefix_type not in (tink_pb2.RAW, tink_pb2.TINK):
        raise core.TinkError('unsupported OutputPrefixType')


class _JwtPublicKeySignWrapper(
    core.PrimitiveWrapper[_jwt_public_key_sign.JwtPublicKeySignInternal,
                          _jwt_public_key_sign.JwtPublicKeySign]):
  """A wrapper for JwtPublicKeySign."""

  def wrap(self,
           pset: core.PrimitiveSet) -> _jwt_public_key_sign.JwtPublicKeySign:
    """Validates pset and wraps it into a single JwtPublicKeySign.

    Args:
      pset: The primitive set to wrap.

    Returns:
      A JwtPublicKeySign backed by pset.

    Raises:
      core.TinkError: If pset contains an unsupported output prefix type.
    """
    _validate_primitive_set(pset)
    return _WrappedJwtPublicKeySign(pset)

  def primitive_class(self) -> Type[_jwt_public_key_sign.JwtPublicKeySign]:
    """Returns the primitive class produced by wrap()."""
    return _jwt_public_key_sign.JwtPublicKeySign

  def input_primitive_class(
      self) -> Type[_jwt_public_key_sign.JwtPublicKeySignInternal]:
    """Returns the primitive class of the entries this wrapper takes."""
    return _jwt_public_key_sign.JwtPublicKeySignInternal


class _JwtPublicKeyVerifyWrapper(
    core.PrimitiveWrapper[_jwt_public_key_verify.JwtPublicKeyVerifyInternal,
                          _jwt_public_key_verify.JwtPublicKeyVerify]):
  """A wrapper for JwtPublicKeyVerify."""

  def wrap(
      self,
      pset: core.PrimitiveSet) -> _jwt_public_key_verify.JwtPublicKeyVerify:
    """Validates pset and wraps it into a single JwtPublicKeyVerify.

    Args:
      pset: The primitive set to wrap.

    Returns:
      A JwtPublicKeyVerify backed by pset.

    Raises:
      core.TinkError: If pset contains an unsupported output prefix type.
    """
    _validate_primitive_set(pset)
    return _WrappedJwtPublicKeyVerify(pset)

  def primitive_class(self) -> Type[_jwt_public_key_verify.JwtPublicKeyVerify]:
    """Returns the primitive class produced by wrap()."""
    return _jwt_public_key_verify.JwtPublicKeyVerify

  def input_primitive_class(
      self) -> Type[_jwt_public_key_verify.JwtPublicKeyVerifyInternal]:
    """Returns the primitive class of the entries this wrapper takes."""
    return _jwt_public_key_verify.JwtPublicKeyVerifyInternal


def register() -> None:
  """Registers the sign and verify wrappers with core.Registry."""
  core.Registry.register_primitive_wrapper(_JwtPublicKeySignWrapper())
  core.Registry.register_primitive_wrapper(_JwtPublicKeyVerifyWrapper())