# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tink.python.tink.jwt._jwt_signature_wrappers."""

import io


from absl.testing import absltest
from absl.testing import parameterized

from tink.proto import jwt_ecdsa_pb2
from tink.proto import jwt_rsa_ssa_pkcs1_pb2
from tink.proto import jwt_rsa_ssa_pss_pb2
from tink.proto import tink_pb2
import tink
from tink import cleartext_keyset_handle
from tink import jwt
from tink.jwt import _json_util
from tink.jwt import _jwt_format

from tink.testing import keyset_builder


LONG_CUSTOM_KID = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit'

# Type-URL suffixes of the JWT private keys supported by _set_custom_kid,
# paired with the proto class used to parse and re-serialize the key value.
_PRIVATE_KEY_CLASSES = (
    ('JwtEcdsaPrivateKey', jwt_ecdsa_pb2.JwtEcdsaPrivateKey),
    ('JwtRsaSsaPkcs1PrivateKey',
     jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PrivateKey),
    ('JwtRsaSsaPssPrivateKey', jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPrivateKey),
)


def setUpModule():
  jwt.register_jwt_signature()


def _keyset_proto(keyset_handle: tink.KeysetHandle) -> tink_pb2.Keyset:
  """Returns a freshly parsed copy of the keyset held by keyset_handle.

  The keyset is written out in cleartext and parsed back, so modifying the
  returned proto does not affect keyset_handle.
  """
  buffer = io.BytesIO()
  cleartext_keyset_handle.write(
      tink.BinaryKeysetWriter(buffer), keyset_handle)
  return tink_pb2.Keyset.FromString(buffer.getvalue())


def _change_key_id(keyset_handle: tink.KeysetHandle) -> tink.KeysetHandle:
  """Changes the key id of the first key and makes it primary."""
  keyset = _keyset_proto(keyset_handle)
  # XOR the key id with an arbitrary 32-bit string to get a new key id.
  new_key_id = keyset.key[0].key_id ^ 0xdeadbeef
  keyset.key[0].key_id = new_key_id
  keyset.primary_key_id = new_key_id
  return cleartext_keyset_handle.from_keyset(keyset)


def _change_output_prefix_to_tink(
    keyset_handle: tink.KeysetHandle) -> tink.KeysetHandle:
  """Changes the output prefix type of the first key to TINK."""
  keyset = _keyset_proto(keyset_handle)
  keyset.key[0].output_prefix_type = tink_pb2.TINK
  return cleartext_keyset_handle.from_keyset(keyset)


def _set_custom_kid(keyset_handle: tink.KeysetHandle,
                    custom_kid: str) -> tink.KeysetHandle:
  """Sets the custom_kid field of the first key.

  Args:
    keyset_handle: A handle whose first key is a JWT ECDSA, RSA SSA PKCS1 or
      RSA SSA PSS private key.
    custom_kid: The value stored in the key's public_key.custom_kid field.

  Returns:
    A new keyset handle containing the modified key; keyset_handle itself is
    left unchanged.

  Raises:
    tink.TinkError: if the first key is not one of the supported key types.
  """
  keyset = _keyset_proto(keyset_handle)
  key_data = keyset.key[0].key_data
  for type_url_suffix, private_key_class in _PRIVATE_KEY_CLASSES:
    if key_data.type_url.endswith(type_url_suffix):
      private_key = private_key_class.FromString(key_data.value)
      private_key.public_key.custom_kid.value = custom_kid
      # key_data refers to the submessage inside keyset, so this updates it.
      key_data.value = private_key.SerializeToString()
      return cleartext_keyset_handle.from_keyset(keyset)
  raise tink.TinkError('unknown key type')


class JwtSignatureWrapperTest(parameterized.TestCase):
  """Tests JwtPublicKeySign / JwtPublicKeyVerify primitives built on keysets."""

  def test_interesting_error(self):
    private_handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    sign = private_handle.primitive(jwt.JwtPublicKeySign)
    verify = private_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    raw_jwt = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)
    compact = sign.sign_and_encode(raw_jwt)
    with self.assertRaisesRegex(jwt.JwtInvalidError,
                                'invalid JWT; expected issuer'):
      verify.verify_and_decode(compact, jwt.new_validator(
          expected_issuer='unknown', allow_missing_expiration=True))

  @parameterized.parameters([
      (jwt.raw_jwt_es256_template(), jwt.raw_jwt_es256_template()),
      (jwt.raw_jwt_es256_template(), jwt.jwt_es256_template()),
      (jwt.jwt_es256_template(), jwt.raw_jwt_es256_template()),
      (jwt.jwt_es256_template(), jwt.jwt_es256_template()),
  ])
  def test_key_rotation(self, old_key_tmpl, new_key_tmpl):
    builder = keyset_builder.new_keyset_builder()
    older_key_id = builder.add_new_key(old_key_tmpl)

    # Stage 1: only the older key, which is primary.
    builder.set_primary_key(older_key_id)
    handle1 = builder.keyset_handle()
    sign1 = handle1.primitive(jwt.JwtPublicKeySign)
    verify1 = handle1.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)

    # Stage 2: newer key added, older key still primary.
    newer_key_id = builder.add_new_key(new_key_tmpl)
    handle2 = builder.keyset_handle()
    sign2 = handle2.primitive(jwt.JwtPublicKeySign)
    verify2 = handle2.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)

    # Stage 3: newer key becomes primary.
    builder.set_primary_key(newer_key_id)
    handle3 = builder.keyset_handle()
    sign3 = handle3.primitive(jwt.JwtPublicKeySign)
    verify3 = handle3.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)

    # Stage 4: older key is disabled.
    builder.disable_key(older_key_id)
    handle4 = builder.keyset_handle()
    sign4 = handle4.primitive(jwt.JwtPublicKeySign)
    verify4 = handle4.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)

    raw_jwt = jwt.new_raw_jwt(issuer='a', without_expiration=True)
    validator = jwt.new_validator(
        expected_issuer='a', allow_missing_expiration=True)

    self.assertNotEqual(older_key_id, newer_key_id)
    # 1 uses the older key. So 1, 2 and 3 can verify the signature, but not 4.
    compact1 = sign1.sign_and_encode(raw_jwt)
    self.assertEqual(
        verify1.verify_and_decode(compact1, validator).issuer(), 'a')
    self.assertEqual(
        verify2.verify_and_decode(compact1, validator).issuer(), 'a')
    self.assertEqual(
        verify3.verify_and_decode(compact1, validator).issuer(), 'a')
    with self.assertRaises(tink.TinkError):
      verify4.verify_and_decode(compact1, validator)

    # 2 uses the older key. So 1, 2 and 3 can verify the signature, but not 4.
    compact2 = sign2.sign_and_encode(raw_jwt)
    self.assertEqual(
        verify1.verify_and_decode(compact2, validator).issuer(), 'a')
    self.assertEqual(
        verify2.verify_and_decode(compact2, validator).issuer(), 'a')
    self.assertEqual(
        verify3.verify_and_decode(compact2, validator).issuer(), 'a')
    with self.assertRaises(tink.TinkError):
      verify4.verify_and_decode(compact2, validator)

    # 3 uses the newer key. So 2, 3 and 4 can verify the signature, but not 1.
    compact3 = sign3.sign_and_encode(raw_jwt)
    with self.assertRaises(tink.TinkError):
      verify1.verify_and_decode(compact3, validator)
    self.assertEqual(
        verify2.verify_and_decode(compact3, validator).issuer(), 'a')
    self.assertEqual(
        verify3.verify_and_decode(compact3, validator).issuer(), 'a')
    self.assertEqual(
        verify4.verify_and_decode(compact3, validator).issuer(), 'a')

    # 4 uses the newer key. So 2, 3 and 4 can verify the signature, but not 1.
    compact4 = sign4.sign_and_encode(raw_jwt)
    with self.assertRaises(tink.TinkError):
      verify1.verify_and_decode(compact4, validator)
    self.assertEqual(
        verify2.verify_and_decode(compact4, validator).issuer(), 'a')
    self.assertEqual(
        verify3.verify_and_decode(compact4, validator).issuer(), 'a')
    self.assertEqual(
        verify4.verify_and_decode(compact4, validator).issuer(), 'a')

  def test_only_tink_output_prefix_type_encodes_a_kid_header(self):
    handle = tink.new_keyset_handle(jwt.raw_jwt_es256_template())
    sign = handle.primitive(jwt.JwtPublicKeySign)
    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)

    # Same key material as handle, but with output prefix type TINK.
    tink_handle = _change_output_prefix_to_tink(handle)
    tink_sign = tink_handle.primitive(jwt.JwtPublicKeySign)
    tink_verify = tink_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)

    raw_jwt = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)

    token = sign.sign_and_encode(raw_jwt)
    token_with_kid = tink_sign.sign_and_encode(raw_jwt)

    _, header, _, _ = _jwt_format.split_signed_compact(token)
    self.assertNotIn('kid', _json_util.json_loads(header))

    _, header_with_kid, _, _ = _jwt_format.split_signed_compact(token_with_kid)
    self.assertIn('kid', _json_util.json_loads(header_with_kid))

    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)

    verify.verify_and_decode(token, validator)
    tink_verify.verify_and_decode(token_with_kid, validator)

    other_handle = _change_key_id(tink_handle)
    other_verify = other_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)

    # The RAW primitive accepts a token signed with the same key even though
    # it carries a kid header.
    verify.verify_and_decode(token_with_kid, validator)
    # For output prefix type TINK, the kid header is required.
    with self.assertRaises(tink.TinkError):
      tink_verify.verify_and_decode(token, validator)
    # This should fail because the value of the kid header is wrong.
    with self.assertRaises(tink.TinkError):
      other_verify.verify_and_decode(token_with_kid, validator)

  @parameterized.named_parameters([
      ('JWT_ES256_RAW', jwt.raw_jwt_es256_template()),
      ('JWT_RS256_RAW', jwt.raw_jwt_rs256_2048_f4_template()),
      ('JWT_PS256_RAW', jwt.raw_jwt_ps256_3072_f4_template()),
  ])
  def test_raw_key_with_custom_kid_header(self, template):
    # normal key with output prefix RAW
    handle = tink.new_keyset_handle(template)
    raw_jwt = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)
    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)

    sign = handle.primitive(jwt.JwtPublicKeySign)
    token = sign.sign_and_encode(raw_jwt)
    verify = handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)
    verify.verify_and_decode(token, validator)

    _, json_header, _, _ = _jwt_format.split_signed_compact(token)
    self.assertNotIn('kid', _json_util.json_loads(json_header))

    # key with a custom_kid set
    custom_kid_handle = _set_custom_kid(handle, custom_kid=LONG_CUSTOM_KID)
    custom_kid_sign = custom_kid_handle.primitive(jwt.JwtPublicKeySign)
    token_with_kid = custom_kid_sign.sign_and_encode(raw_jwt)
    custom_kid_verify = custom_kid_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    custom_kid_verify.verify_and_decode(token_with_kid, validator)

    _, header_with_kid, _, _ = _jwt_format.split_signed_compact(token_with_kid)
    self.assertEqual(_json_util.json_loads(header_with_kid)['kid'],
                     LONG_CUSTOM_KID)

    # The primitive with a custom_kid set accepts tokens without kid header.
    custom_kid_verify.verify_and_decode(token, validator)

    # The primitive without a custom_kid set ignores the kid header.
    verify.verify_and_decode(token_with_kid, validator)

    # key with a different custom_kid set
    other_handle = _set_custom_kid(handle, custom_kid='other kid')
    other_verify = other_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    # Fails because the kid values do not match.
    with self.assertRaises(tink.TinkError):
      other_verify.verify_and_decode(token_with_kid, validator)

    tink_handle = _change_output_prefix_to_tink(custom_kid_handle)
    tink_sign = tink_handle.primitive(jwt.JwtPublicKeySign)
    tink_verify = tink_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    # Having custom_kid set with output prefix TINK is not allowed.
    with self.assertRaises(tink.TinkError):
      tink_sign.sign_and_encode(raw_jwt)
    with self.assertRaises(tink.TinkError):
      tink_verify.verify_and_decode(token, validator)
    with self.assertRaises(tink.TinkError):
      tink_verify.verify_and_decode(token_with_kid, validator)

  def test_legacy_template_fails(self):
    template = keyset_builder.legacy_template(jwt.jwt_es256_template())
    builder = keyset_builder.new_keyset_builder()
    key_id = builder.add_new_key(template)
    builder.set_primary_key(key_id)
    handle = builder.keyset_handle()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtPublicKeySign)
    with self.assertRaises(tink.TinkError):
      handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)

  def test_legacy_non_primary_key_fails(self):
    builder = keyset_builder.new_keyset_builder()
    old_template = keyset_builder.legacy_template(jwt.jwt_es256_template())
    _ = builder.add_new_key(old_template)
    current_key_id = builder.add_new_key(jwt.jwt_es256_template())
    builder.set_primary_key(current_key_id)
    handle = builder.keyset_handle()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtPublicKeySign)
    with self.assertRaises(tink.TinkError):
      handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)

  def test_jwt_signature_from_keyset_without_primary_fails(self):
    # The keyset holds an ES256 signature key but no primary key is set.
    builder = keyset_builder.new_keyset_builder()
    builder.add_new_key(jwt.jwt_es256_template())
    with self.assertRaises(tink.TinkError):
      builder.keyset_handle()


if __name__ == '__main__':
  absltest.main()