# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JWT testing service API implementations in Python."""

import datetime
import io
import json
from typing import Tuple

import grpc
import tink
from tink import cleartext_keyset_handle
from tink import jwt

from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
from protos import testing_api_pb2
from protos import testing_api_pb2_grpc

# Reference point for exact (integer) conversions between datetimes and
# protobuf Timestamp messages.
_EPOCH = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
_SECONDS_PER_DAY = 24 * 60 * 60
_NANOS_PER_MICROSECOND = 1000


def _to_timestamp_tuple(t: datetime.datetime) -> Tuple[int, int]:
  """Splits a timezone-aware datetime into (seconds, nanos) since the epoch.

  The conversion uses integer timedelta arithmetic instead of floats, so the
  microsecond component is preserved exactly. For datetimes before the epoch
  the seconds are floored and nanos stay non-negative.

  Args:
    t: The datetime to convert. It must carry tzinfo.

  Returns:
    A (seconds, nanos) tuple with 0 <= nanos < 1e9.

  Raises:
    ValueError: If t has no tzinfo.
  """
  if not t.tzinfo:
    raise ValueError('datetime must have tzinfo')
  # timedelta normalizes so that 0 <= seconds < 86400 and
  # 0 <= microseconds < 1e6; only `days` can be negative.
  delta = t - _EPOCH
  seconds = delta.days * _SECONDS_PER_DAY + delta.seconds
  nanos = delta.microseconds * _NANOS_PER_MICROSECOND
  return (seconds, nanos)


def _from_timestamp_proto(
    timestamp: timestamp_pb2.Timestamp) -> datetime.datetime:
  """Converts a proto Timestamp into a UTC datetime.

  Nanos are truncated to microseconds (the resolution of datetime). Integer
  arithmetic is used so that a value can never be rounded into the next
  second by floating point error.

  Args:
    timestamp: The proto Timestamp to convert.

  Returns:
    A datetime with tzinfo set to UTC.
  """
  return _EPOCH + datetime.timedelta(
      seconds=timestamp.seconds,
      microseconds=timestamp.nanos // _NANOS_PER_MICROSECOND)


def _from_duration_proto(
    duration: duration_pb2.Duration) -> datetime.timedelta:
  """Converts a proto Duration into a timedelta.

  Both the seconds and the nanos field are taken into account; nanos are
  truncated toward zero to microseconds.

  Args:
    duration: The proto Duration to convert.

  Returns:
    The equivalent timedelta.
  """
  # int() truncates toward zero, which is what we want for the (same-signed)
  # nanos of a negative Duration.
  return datetime.timedelta(
      seconds=duration.seconds,
      microseconds=int(duration.nanos / _NANOS_PER_MICROSECOND))


def _read_keyset_handle(serialized_keyset: bytes) -> tink.KeysetHandle:
  """Parses a binary serialized cleartext keyset into a KeysetHandle."""
  return cleartext_keyset_handle.read(
      tink.BinaryKeysetReader(serialized_keyset))


def raw_jwt_from_proto(proto_raw_jwt: testing_api_pb2.JwtToken) -> jwt.RawJwt:
  """Converts a proto JwtToken into a jwt.RawJwt.

  Optional fields that are not set on the proto are passed on as None. If no
  expiration is set, the token is created with without_expiration=True.

  Args:
    proto_raw_jwt: The proto token to convert.

  Returns:
    The jwt.RawJwt built by jwt.new_raw_jwt.

  Raises:
    ValueError: If a custom claim has none of the known value fields set.
  """
  type_header = None
  if proto_raw_jwt.HasField('type_header'):
    type_header = proto_raw_jwt.type_header.value
  issuer = None
  if proto_raw_jwt.HasField('issuer'):
    issuer = proto_raw_jwt.issuer.value
  subject = None
  if proto_raw_jwt.HasField('subject'):
    subject = proto_raw_jwt.subject.value
  # An empty repeated field is passed on as "no audiences".
  audiences = list(proto_raw_jwt.audiences) or None
  jwt_id = None
  if proto_raw_jwt.HasField('jwt_id'):
    jwt_id = proto_raw_jwt.jwt_id.value

  custom_claims = {}
  for name, claim in proto_raw_jwt.custom_claims.items():
    if claim.HasField('null_value'):
      custom_claims[name] = None
    elif claim.HasField('number_value'):
      custom_claims[name] = claim.number_value
    elif claim.HasField('string_value'):
      custom_claims[name] = claim.string_value
    elif claim.HasField('bool_value'):
      custom_claims[name] = claim.bool_value
    elif claim.HasField('json_object_value'):
      # Objects and arrays travel as JSON encoded strings.
      custom_claims[name] = json.loads(claim.json_object_value)
    elif claim.HasField('json_array_value'):
      custom_claims[name] = json.loads(claim.json_array_value)
    else:
      raise ValueError('claim %s has unknown type' % name)

  expiration = None
  if proto_raw_jwt.HasField('expiration'):
    expiration = _from_timestamp_proto(proto_raw_jwt.expiration)
  not_before = None
  if proto_raw_jwt.HasField('not_before'):
    not_before = _from_timestamp_proto(proto_raw_jwt.not_before)
  issued_at = None
  if proto_raw_jwt.HasField('issued_at'):
    issued_at = _from_timestamp_proto(proto_raw_jwt.issued_at)

  return jwt.new_raw_jwt(
      type_header=type_header,
      issuer=issuer,
      subject=subject,
      audiences=audiences,
      jwt_id=jwt_id,
      expiration=expiration,
      without_expiration=expiration is None,
      not_before=not_before,
      issued_at=issued_at,
      custom_claims=custom_claims)


def verifiedjwt_to_proto(
    verified_jwt: jwt.VerifiedJwt) -> testing_api_pb2.JwtToken:
  """Converts a jwt.VerifiedJwt into a proto JwtToken.

  Only the claims that are present on the verified token are set on the
  proto.

  Args:
    verified_jwt: The verified token to convert.

  Returns:
    A proto JwtToken carrying the same claims.

  Raises:
    ValueError: If a custom claim value has an unsupported Python type, or a
      timestamp claim has no tzinfo.
  """
  token = testing_api_pb2.JwtToken()
  if verified_jwt.has_type_header():
    token.type_header.value = verified_jwt.type_header()
  if verified_jwt.has_issuer():
    token.issuer.value = verified_jwt.issuer()
  if verified_jwt.has_subject():
    token.subject.value = verified_jwt.subject()
  if verified_jwt.has_audiences():
    token.audiences.extend(verified_jwt.audiences())
  if verified_jwt.has_jwt_id():
    token.jwt_id.value = verified_jwt.jwt_id()

  # (presence check, getter, destination Timestamp field) per time claim.
  timestamp_claims = (
      (verified_jwt.has_expiration, verified_jwt.expiration, 'expiration'),
      (verified_jwt.has_not_before, verified_jwt.not_before, 'not_before'),
      (verified_jwt.has_issued_at, verified_jwt.issued_at, 'issued_at'),
  )
  for is_present, getter, field_name in timestamp_claims:
    if is_present():
      seconds, nanos = _to_timestamp_tuple(getter())
      proto_timestamp = getattr(token, field_name)
      proto_timestamp.seconds = seconds
      proto_timestamp.nanos = nanos

  for name in verified_jwt.custom_claim_names():
    value = verified_jwt.custom_claim(name)
    if value is None:
      token.custom_claims[name].null_value = testing_api_pb2.NULL_VALUE
    elif isinstance(value, bool):
      # bool must be tested before int: bool is a subclass of int.
      token.custom_claims[name].bool_value = value
    elif isinstance(value, (int, float)):
      token.custom_claims[name].number_value = value
    elif isinstance(value, str):
      token.custom_claims[name].string_value = value
    elif isinstance(value, dict):
      token.custom_claims[name].json_object_value = json.dumps(value)
    elif isinstance(value, list):
      token.custom_claims[name].json_array_value = json.dumps(value)
    else:
      raise ValueError('claim %s has unknown type' % name)
  return token


def validator_from_proto(
    proto_validator: testing_api_pb2.JwtValidator) -> jwt.JwtValidator:
  """Converts a proto JwtValidator into a JwtValidator.

  Optional fields that are not set on the proto are passed on as None.

  Args:
    proto_validator: The proto validator to convert.

  Returns:
    The jwt.JwtValidator built by jwt.new_validator.
  """
  expected_type_header = None
  if proto_validator.HasField('expected_type_header'):
    expected_type_header = proto_validator.expected_type_header.value
  expected_issuer = None
  if proto_validator.HasField('expected_issuer'):
    expected_issuer = proto_validator.expected_issuer.value
  expected_audience = None
  if proto_validator.HasField('expected_audience'):
    expected_audience = proto_validator.expected_audience.value
  fixed_now = None
  if proto_validator.HasField('now'):
    fixed_now = _from_timestamp_proto(proto_validator.now)
  clock_skew = None
  if proto_validator.HasField('clock_skew'):
    clock_skew = _from_duration_proto(proto_validator.clock_skew)
  return jwt.new_validator(
      expected_type_header=expected_type_header,
      expected_issuer=expected_issuer,
      expected_audience=expected_audience,
      ignore_type_header=proto_validator.ignore_type_header,
      ignore_issuer=proto_validator.ignore_issuer,
      # The proto field is singular, the Python keyword is plural.
      ignore_audiences=proto_validator.ignore_audience,
      allow_missing_expiration=proto_validator.allow_missing_expiration,
      expect_issued_in_the_past=proto_validator.expect_issued_in_the_past,
      fixed_now=fixed_now,
      clock_skew=clock_skew)


class JwtServicer(testing_api_pb2_grpc.JwtServicer):
  """A service for signing and verifying JWTs.

  Every method reports a tink.TinkError through the `err` field of its
  response message instead of raising. Other exception types are not caught.
  """

  def CreateJwtMac(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Creates a JwtMac without using it."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      keyset_handle.primitive(jwt.JwtMac)
      return testing_api_pb2.CreationResponse()
    except tink.TinkError as e:
      return testing_api_pb2.CreationResponse(err=str(e))

  def CreateJwtPublicKeySign(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Creates a JwtPublicKeySign without using it."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      keyset_handle.primitive(jwt.JwtPublicKeySign)
      return testing_api_pb2.CreationResponse()
    except tink.TinkError as e:
      return testing_api_pb2.CreationResponse(err=str(e))

  def CreateJwtPublicKeyVerify(
      self, request: testing_api_pb2.CreationRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.CreationResponse:
    """Creates a JwtPublicKeyVerify without using it."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      keyset_handle.primitive(jwt.JwtPublicKeyVerify)
      return testing_api_pb2.CreationResponse()
    except tink.TinkError as e:
      return testing_api_pb2.CreationResponse(err=str(e))

  def ComputeMacAndEncode(
      self, request: testing_api_pb2.JwtSignRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtSignResponse:
    """Computes a MACed compact JWT."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      p = keyset_handle.primitive(jwt.JwtMac)
      raw_jwt = raw_jwt_from_proto(request.raw_jwt)
      signed_compact_jwt = p.compute_mac_and_encode(raw_jwt)
      return testing_api_pb2.JwtSignResponse(
          signed_compact_jwt=signed_compact_jwt)
    except tink.TinkError as e:
      return testing_api_pb2.JwtSignResponse(err=str(e))

  def VerifyMacAndDecode(
      self, request: testing_api_pb2.JwtVerifyRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtVerifyResponse:
    """Verifies a MAC value."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      validator = validator_from_proto(request.validator)
      p = keyset_handle.primitive(jwt.JwtMac)
      verified_jwt = p.verify_mac_and_decode(request.signed_compact_jwt,
                                             validator)
      return testing_api_pb2.JwtVerifyResponse(
          verified_jwt=verifiedjwt_to_proto(verified_jwt))
    except tink.TinkError as e:
      return testing_api_pb2.JwtVerifyResponse(err=str(e))

  def PublicKeySignAndEncode(
      self, request: testing_api_pb2.JwtSignRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtSignResponse:
    """Computes a signed compact JWT token."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      p = keyset_handle.primitive(jwt.JwtPublicKeySign)
      raw_jwt = raw_jwt_from_proto(request.raw_jwt)
      signed_compact_jwt = p.sign_and_encode(raw_jwt)
      return testing_api_pb2.JwtSignResponse(
          signed_compact_jwt=signed_compact_jwt)
    except tink.TinkError as e:
      return testing_api_pb2.JwtSignResponse(err=str(e))

  def PublicKeyVerifyAndDecode(
      self, request: testing_api_pb2.JwtVerifyRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtVerifyResponse:
    """Verifies the validity of the signed compact JWT token."""
    try:
      keyset_handle = _read_keyset_handle(
          request.annotated_keyset.serialized_keyset)
      validator = validator_from_proto(request.validator)
      p = keyset_handle.primitive(jwt.JwtPublicKeyVerify)
      verified_jwt = p.verify_and_decode(request.signed_compact_jwt, validator)
      return testing_api_pb2.JwtVerifyResponse(
          verified_jwt=verifiedjwt_to_proto(verified_jwt))
    except tink.TinkError as e:
      return testing_api_pb2.JwtVerifyResponse(err=str(e))

  def ToJwkSet(
      self, request: testing_api_pb2.JwtToJwkSetRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtToJwkSetResponse:
    """Converts a Tink Keyset with JWT keys into a JWK set."""
    try:
      keyset_handle = _read_keyset_handle(request.keyset)
      jwk_set = jwt.jwk_set_from_public_keyset_handle(keyset_handle)
      return testing_api_pb2.JwtToJwkSetResponse(jwk_set=jwk_set)
    except tink.TinkError as e:
      return testing_api_pb2.JwtToJwkSetResponse(err=str(e))

  def FromJwkSet(
      self, request: testing_api_pb2.JwtFromJwkSetRequest,
      context: grpc.ServicerContext) -> testing_api_pb2.JwtFromJwkSetResponse:
    """Converts a JWK set into a Tink Keyset."""
    try:
      keyset_handle = jwt.jwk_set_to_public_keyset_handle(request.jwk_set)
      # Serialize the resulting keyset into an in-memory binary buffer.
      keyset = io.BytesIO()
      cleartext_keyset_handle.write(
          tink.BinaryKeysetWriter(keyset), keyset_handle)
      return testing_api_pb2.JwtFromJwkSetResponse(keyset=keyset.getvalue())
    except tink.TinkError as e:
      return testing_api_pb2.JwtFromJwkSetResponse(err=str(e))