# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cross-language tests for the "kid" header set by JWT primitives."""

import base64
import json
from typing import Optional

from absl.testing import absltest
from absl.testing import parameterized
import tink
from tink import jwt

from tink.proto import jwt_ecdsa_pb2
from tink.proto import jwt_hmac_pb2
from tink.proto import jwt_rsa_ssa_pkcs1_pb2
from tink.proto import jwt_rsa_ssa_pss_pb2
from tink.proto import tink_pb2
from util import testing_servers
from util import utilities

SUPPORTED_LANGUAGES = testing_servers.SUPPORTED_LANGUAGES_BY_PRIMITIVE['jwt']

# Maps a template-name prefix to the private key proto that has to be parsed
# in order to reach the "custom_kid" field of the embedded public key.
_PRIVATE_KEY_CLASS_BY_TEMPLATE_PREFIX = (
    ('JWT_ES256', jwt_ecdsa_pb2.JwtEcdsaPrivateKey),
    ('JWT_RS256', jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PrivateKey),
    ('JWT_PS256', jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPrivateKey),
)


def setUpModule():
  """Registers the JWT primitives and starts the 'jwt' testing servers."""
  jwt.register_jwt_mac()
  jwt.register_jwt_signature()
  testing_servers.start('jwt')


def tearDownModule():
  """Stops the testing servers started in setUpModule."""
  testing_servers.stop()


def base64_decode(encoded_data: bytes) -> bytes:
  """Decodes URL-safe base64 data whose trailing padding may be missing.

  Args:
    encoded_data: URL-safe base64 encoded bytes, with or without '=' padding.

  Returns:
    The decoded bytes.
  """
  # Always append padding instead of computing how much is missing; the
  # decoder tolerates surplus '=' characters.
  padded_encoded_data = encoded_data + b'==='
  return base64.urlsafe_b64decode(padded_encoded_data)


def decode_kid(compact: str) -> Optional[str]:
  """Returns the "kid" header of a compact JWT, or None if it has none.

  Args:
    compact: A token made of exactly three '.'-separated parts, the first of
      which is the base64url-encoded JSON header.

  Returns:
    The value stored under "kid" in the header, or None if the key is absent.

  Raises:
    ValueError: If the token does not consist of exactly three parts.
  """
  encoded_header, _, _ = compact.encode('utf8').split(b'.')
  header = json.loads(base64_decode(encoded_header))
  return header.get('kid')


def generate_jwt_mac_keyset_with_custom_kid(
    template_name: str, custom_kid: str) -> tink_pb2.Keyset:
  """Generates a JWT MAC keyset whose first key has custom_kid set.

  Args:
    template_name: Name of an entry in utilities.KEY_TEMPLATE. Only names
      starting with 'JWT_HS256' are supported.
    custom_kid: Value to store in the key's custom_kid field.

  Returns:
    The keyset proto with the modified key.

  Raises:
    ValueError: If template_name does not start with 'JWT_HS256'.
  """
  key_template = utilities.KEY_TEMPLATE[template_name]
  keyset_handle = tink.new_keyset_handle(key_template)
  if not template_name.startswith('JWT_HS256'):
    raise ValueError('unknown alg')
  # There is no public API to set custom_kid, so parse key_data.value, set
  # custom_kid and write the re-serialized key back into the keyset proto.
  key_data = keyset_handle._keyset.key[0].key_data
  hmac_key = jwt_hmac_pb2.JwtHmacKey.FromString(key_data.value)
  hmac_key.custom_kid.value = custom_kid
  key_data.value = hmac_key.SerializeToString()
  return keyset_handle._keyset


def _private_key_class_for_template(template_name: str):
  """Returns the private key proto class matching template_name's prefix."""
  for prefix, private_key_class in _PRIVATE_KEY_CLASS_BY_TEMPLATE_PREFIX:
    if template_name.startswith(prefix):
      return private_key_class
  raise ValueError('unknown template name')


def generate_jwt_signature_keyset_with_custom_kid(
    template_name: str, custom_kid: str) -> tink_pb2.Keyset:
  """Generates a JWT signature keyset whose first key has custom_kid set.

  The custom_kid is stored on the public key embedded in the private key.

  Args:
    template_name: Name of an entry in utilities.KEY_TEMPLATE. Only names
      starting with 'JWT_ES256', 'JWT_RS256' or 'JWT_PS256' are supported.
    custom_kid: Value to store in the public key's custom_kid field.

  Returns:
    The keyset proto with the modified private key.

  Raises:
    ValueError: If template_name has none of the supported prefixes.
  """
  key_template = utilities.KEY_TEMPLATE[template_name]
  keyset_handle = tink.new_keyset_handle(key_template)
  private_key_class = _private_key_class_for_template(template_name)
  # There is no public API to set custom_kid, so parse key_data.value, set
  # custom_kid and write the re-serialized key back into the keyset proto.
  key_data = keyset_handle._keyset.key[0].key_data
  private_key = private_key_class.FromString(key_data.value)
  private_key.public_key.custom_kid.value = custom_kid
  key_data.value = private_key.SerializeToString()
  return keyset_handle._keyset


class JwtKidTest(parameterized.TestCase):
  """Tests that all JWT primitives consistently add a "kid" header to tokens."""

  @parameterized.parameters(['JWT_HS256'])
  def test_jwt_mac_sets_kid_for_tink_templates(self, template_name):
    """Every language puts a "kid" header into MACed tokens."""
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      jwt_mac = testing_servers.remote_primitive(lang, keyset, jwt.JwtMac)
      compact = jwt_mac.compute_mac_and_encode(raw_jwt)
      self.assertIsNotNone(decode_kid(compact))

  @parameterized.parameters(['JWT_HS256_RAW'])
  def test_jwt_mac_does_not_sets_kid_for_raw_templates(self, template_name):
    """No language puts a "kid" header into tokens MACed with raw keys."""
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      jwt_mac = testing_servers.remote_primitive(lang, keyset, jwt.JwtMac)
      compact = jwt_mac.compute_mac_and_encode(raw_jwt)
      self.assertIsNone(decode_kid(compact))

  @parameterized.parameters(
      ['JWT_ES256', 'JWT_RS256_2048_F4', 'JWT_PS256_2048_F4'])
  def test_jwt_public_key_sign_sets_kid_for_tink_templates(self, template_name):
    """Every supporting language puts a "kid" header into signed tokens."""
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      jwt_sign = testing_servers.remote_primitive(lang, keyset,
                                                  jwt.JwtPublicKeySign)
      compact = jwt_sign.sign_and_encode(raw_jwt)
      self.assertIsNotNone(decode_kid(compact))

  @parameterized.parameters(
      ['JWT_ES256_RAW', 'JWT_RS256_2048_F4_RAW', 'JWT_PS256_2048_F4_RAW'])
  def test_jwt_public_key_sign_does_not_sets_kid_for_raw_templates(
      self, template_name):
    """No supporting language puts a "kid" header into raw-key signed tokens."""
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      jwt_sign = testing_servers.remote_primitive(lang, keyset,
                                                  jwt.JwtPublicKeySign)
      compact = jwt_sign.sign_and_encode(raw_jwt)
      self.assertIsNone(decode_kid(compact))

  @parameterized.parameters(['JWT_HS256_RAW'])
  def test_jwt_mac_sets_custom_kid_for_raw_keys(self, template_name):
    """A custom_kid on a raw MAC key is emitted as the token's "kid" header."""
    keyset = generate_jwt_mac_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      jwt_mac = testing_servers.remote_primitive(lang,
                                                 keyset.SerializeToString(),
                                                 jwt.JwtMac)
      compact = jwt_mac.compute_mac_and_encode(raw_jwt)
      self.assertEqual(decode_kid(compact), 'my kid')

  @parameterized.parameters(['JWT_HS256'])
  def test_jwt_mac_fails_for_tink_keys_with_custom_kid(self, template_name):
    """Creating or using a MAC from a non-raw key with custom_kid fails."""
    keyset = generate_jwt_mac_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      # The error may surface either when the primitive is created or when it
      # is first used, so both calls are inside the assertRaises block.
      with self.assertRaises(
          tink.TinkError,
          msg=('%s supports JWT mac keys with TINK output prefix type '
               'and custom_kid set unexpectedly') % lang):
        jwt_mac = testing_servers.remote_primitive(lang,
                                                   keyset.SerializeToString(),
                                                   jwt.JwtMac)
        jwt_mac.compute_mac_and_encode(raw_jwt)

  @parameterized.parameters(
      ['JWT_ES256_RAW', 'JWT_RS256_2048_F4_RAW', 'JWT_PS256_2048_F4_RAW'])
  def test_jwt_public_key_sign_sets_custom_kid_for_raw_keys(
      self, template_name):
    """A custom_kid on a raw signing key is emitted as the "kid" header."""
    keyset = generate_jwt_signature_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      jwt_sign = testing_servers.remote_primitive(lang,
                                                  keyset.SerializeToString(),
                                                  jwt.JwtPublicKeySign)
      compact = jwt_sign.sign_and_encode(raw_jwt)
      self.assertEqual(decode_kid(compact), 'my kid')

  @parameterized.parameters(
      ['JWT_ES256', 'JWT_RS256_2048_F4', 'JWT_PS256_2048_F4'])
  def test_jwt_public_key_sign_fails_for_tink_keys_with_custom_kid(
      self, template_name):
    """Creating or using a signer from a non-raw key with custom_kid fails."""
    keyset = generate_jwt_signature_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    # Only check languages that support the template: a language without
    # support would raise TinkError for an unrelated reason and make the
    # assertion below pass vacuously.
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      # The error may surface either when the primitive is created or when it
      # is first used, so both calls are inside the assertRaises block.
      with self.assertRaises(
          tink.TinkError,
          msg=('%s supports JWT signature keys with TINK output prefix type '
               'and custom_kid set unexpectedly') % lang):
        jwt_sign = testing_servers.remote_primitive(lang,
                                                    keyset.SerializeToString(),
                                                    jwt.JwtPublicKeySign)
        jwt_sign.sign_and_encode(raw_jwt)


if __name__ == '__main__':
  absltest.main()