xref: /aosp_15_r20/external/tink/python/tink/jwt/_jwt_signature_wrappers_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Tests for tink.jwt._jwt_signature_wrappers."""
15
16import io
17
18
19from absl.testing import absltest
20from absl.testing import parameterized
21
22from tink.proto import jwt_ecdsa_pb2
23from tink.proto import jwt_rsa_ssa_pkcs1_pb2
24from tink.proto import jwt_rsa_ssa_pss_pb2
25from tink.proto import tink_pb2
26import tink
27from tink import cleartext_keyset_handle
28from tink import jwt
29from tink.jwt import _json_util
30from tink.jwt import _jwt_format
31
32from tink.testing import keyset_builder
33
34
# Custom "kid" value used by the custom_kid tests below; deliberately long and
# containing spaces and punctuation.
LONG_CUSTOM_KID = (
    'Lorem ipsum dolor sit amet, consectetur adipiscing elit')
36
37
def setUpModule():
  """Registers the JWT signature key types once, before any test runs."""
  jwt.register_jwt_signature()
40
41
def _change_key_id(keyset_handle: tink.KeysetHandle) -> tink.KeysetHandle:
  """Returns a new handle whose first key has a different id and is primary.

  The input handle is serialized in cleartext and re-parsed, so the handle
  passed in is left untouched.

  Args:
    keyset_handle: Handle whose keyset is used as the starting point.

  Returns:
    A handle over the modified keyset.
  """
  serialized = io.BytesIO()
  writer = tink.BinaryKeysetWriter(serialized)
  cleartext_keyset_handle.write(writer, keyset_handle)
  keyset_proto = tink_pb2.Keyset.FromString(serialized.getvalue())

  first_key = keyset_proto.key[0]
  # XOR with an arbitrary non-zero 32-bit constant always yields a new key id.
  first_key.key_id ^= 0xdeadbeef
  keyset_proto.primary_key_id = first_key.key_id
  return cleartext_keyset_handle.from_keyset(keyset_proto)
53
54
def _change_output_prefix_to_tink(
    keyset_handle: tink.KeysetHandle) -> tink.KeysetHandle:
  """Returns a new handle whose first key uses output prefix type TINK.

  Args:
    keyset_handle: Handle whose keyset is used as the starting point.

  Returns:
    A handle over a re-parsed copy of the keyset with the first key's
    output_prefix_type set to tink_pb2.TINK.
  """
  serialized = io.BytesIO()
  writer = tink.BinaryKeysetWriter(serialized)
  cleartext_keyset_handle.write(writer, keyset_handle)
  keyset_proto = tink_pb2.Keyset.FromString(serialized.getvalue())

  first_key = keyset_proto.key[0]
  first_key.output_prefix_type = tink_pb2.TINK
  return cleartext_keyset_handle.from_keyset(keyset_proto)
64
65
def _set_custom_kid(keyset_handle: tink.KeysetHandle,
                    custom_kid: str) -> tink.KeysetHandle:
  """Returns a new handle whose first key carries the given custom_kid.

  Args:
    keyset_handle: Handle over a keyset whose first key is a JWT ECDSA,
      RSA-SSA-PKCS1 or RSA-SSA-PSS private key.
    custom_kid: Value to store in the public key's custom_kid field.

  Returns:
    A handle over a re-parsed copy of the keyset with the field set.

  Raises:
    tink.TinkError: If the first key's type URL matches none of the supported
      private key types.
  """
  # Type URL suffix -> proto message used to parse that private key. The
  # suffixes are checked in this order.
  private_key_protos = (
      ('JwtEcdsaPrivateKey', jwt_ecdsa_pb2.JwtEcdsaPrivateKey),
      ('JwtRsaSsaPkcs1PrivateKey',
       jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PrivateKey),
      ('JwtRsaSsaPssPrivateKey',
       jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPrivateKey),
  )

  serialized = io.BytesIO()
  writer = tink.BinaryKeysetWriter(serialized)
  cleartext_keyset_handle.write(writer, keyset_handle)
  keyset_proto = tink_pb2.Keyset.FromString(serialized.getvalue())

  key_data = keyset_proto.key[0].key_data
  for suffix, proto_cls in private_key_protos:
    if key_data.type_url.endswith(suffix):
      private_key = proto_cls.FromString(key_data.value)
      private_key.public_key.custom_kid.value = custom_kid
      key_data.value = private_key.SerializeToString()
      break
  else:
    raise tink.TinkError('unknown key type')
  return cleartext_keyset_handle.from_keyset(keyset_proto)
91
92
class JwtSignatureWrapperTest(parameterized.TestCase):
  """Tests JwtPublicKeySign / JwtPublicKeyVerify primitives built from keysets."""

  def test_interesting_error(self):
    """A validation failure is reported with a message naming the issuer."""
    handle = tink.new_keyset_handle(jwt.jwt_es256_template())
    signer = handle.primitive(jwt.JwtPublicKeySign)
    verifier = handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    token = signer.sign_and_encode(
        jwt.new_raw_jwt(issuer='issuer', without_expiration=True))
    mismatched_validator = jwt.new_validator(
        expected_issuer='unknown', allow_missing_expiration=True)
    with self.assertRaisesRegex(jwt.JwtInvalidError,
                                'invalid JWT; expected issuer'):
      verifier.verify_and_decode(token, mismatched_validator)

  @parameterized.parameters([
      (jwt.raw_jwt_es256_template(), jwt.raw_jwt_es256_template()),
      (jwt.raw_jwt_es256_template(), jwt.jwt_es256_template()),
      (jwt.jwt_es256_template(), jwt.raw_jwt_es256_template()),
      (jwt.jwt_es256_template(), jwt.jwt_es256_template()),
  ])
  def test_key_rotation(self, old_key_tmpl, new_key_tmpl):
    """Walks a keyset through a key rotation and checks who verifies what."""
    builder = keyset_builder.new_keyset_builder()

    def primitives_for_current_keyset():
      # Snapshot the builder's current keyset as a (signer, verifier) pair.
      handle = builder.keyset_handle()
      signer = handle.primitive(jwt.JwtPublicKeySign)
      verifier = handle.public_keyset_handle().primitive(
          jwt.JwtPublicKeyVerify)
      return signer, verifier

    # Stage 1: only the older key, which is primary.
    older_key_id = builder.add_new_key(old_key_tmpl)
    builder.set_primary_key(older_key_id)
    sign1, verify1 = primitives_for_current_keyset()

    # Stage 2: the newer key is added, the older key is still primary.
    newer_key_id = builder.add_new_key(new_key_tmpl)
    sign2, verify2 = primitives_for_current_keyset()

    # Stage 3: the newer key becomes primary.
    builder.set_primary_key(newer_key_id)
    sign3, verify3 = primitives_for_current_keyset()

    # Stage 4: the older key is disabled.
    builder.disable_key(older_key_id)
    sign4, verify4 = primitives_for_current_keyset()

    raw_jwt = jwt.new_raw_jwt(issuer='a', without_expiration=True)
    validator = jwt.new_validator(
        expected_issuer='a', allow_missing_expiration=True)

    self.assertNotEqual(older_key_id, newer_key_id)

    verifiers = (verify1, verify2, verify3, verify4)
    # For each signer: whether verifier 1, 2, 3 and 4 accept its tokens.
    # Signers 1 and 2 use the older key, so verifier 4 rejects their tokens.
    # Signers 3 and 4 use the newer key, so verifier 1 rejects their tokens.
    expectations = (
        (sign1, (True, True, True, False)),
        (sign2, (True, True, True, False)),
        (sign3, (False, True, True, True)),
        (sign4, (False, True, True, True)),
    )
    for signer, accepted_by in expectations:
      compact = signer.sign_and_encode(raw_jwt)
      for verifier, accepted in zip(verifiers, accepted_by):
        if accepted:
          self.assertEqual(
              verifier.verify_and_decode(compact, validator).issuer(), 'a')
        else:
          with self.assertRaises(tink.TinkError):
            verifier.verify_and_decode(compact, validator)

  def test_only_tink_output_prefix_type_encodes_a_kid_header(self):
    """RAW keys emit no kid header; TINK keys emit one and require it."""
    raw_handle = tink.new_keyset_handle(jwt.raw_jwt_es256_template())
    raw_signer = raw_handle.primitive(jwt.JwtPublicKeySign)
    raw_verifier = raw_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)

    tink_handle = _change_output_prefix_to_tink(raw_handle)
    tink_signer = tink_handle.primitive(jwt.JwtPublicKeySign)
    tink_verifier = tink_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)

    payload = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)

    plain_token = raw_signer.sign_and_encode(payload)
    kid_token = tink_signer.sign_and_encode(payload)

    _, plain_header, _, _ = _jwt_format.split_signed_compact(plain_token)
    self.assertNotIn('kid', _json_util.json_loads(plain_header))

    _, kid_header, _, _ = _jwt_format.split_signed_compact(kid_token)
    self.assertIn('kid', _json_util.json_loads(kid_header))

    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)

    # Each verifier accepts the token produced by its own signer.
    raw_verifier.verify_and_decode(plain_token, validator)
    tink_verifier.verify_and_decode(kid_token, validator)

    # Same key material, but with a different key id.
    changed_id_handle = _change_key_id(tink_handle)
    changed_id_verifier = changed_id_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)

    # The RAW verifier accepts the token that carries a kid header.
    raw_verifier.verify_and_decode(kid_token, validator)
    # For output prefix type TINK, the kid header is required.
    with self.assertRaises(tink.TinkError):
      tink_verifier.verify_and_decode(plain_token, validator)
    # This should fail because the value of the kid header is wrong.
    with self.assertRaises(tink.TinkError):
      changed_id_verifier.verify_and_decode(kid_token, validator)

  @parameterized.named_parameters([
      ('JWT_ES256_RAW', jwt.raw_jwt_es256_template()),
      ('JWT_RS256_RAW', jwt.raw_jwt_rs256_2048_f4_template()),
      ('JWT_PS256_RAW', jwt.raw_jwt_ps256_3072_f4_template()),
  ])
  def test_raw_key_with_custom_kid_header(self, template):
    """Checks how a custom_kid on a RAW key affects signing and verifying."""
    # Plain key with output prefix RAW and no custom_kid.
    plain_handle = tink.new_keyset_handle(template)
    payload = jwt.new_raw_jwt(issuer='issuer', without_expiration=True)
    validator = jwt.new_validator(
        expected_issuer='issuer', allow_missing_expiration=True)

    plain_signer = plain_handle.primitive(jwt.JwtPublicKeySign)
    plain_token = plain_signer.sign_and_encode(payload)
    plain_verifier = plain_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    plain_verifier.verify_and_decode(plain_token, validator)

    _, plain_header, _, _ = _jwt_format.split_signed_compact(plain_token)
    self.assertNotIn('kid', _json_util.json_loads(plain_header))

    # Same key, now with a custom_kid set.
    kid_handle = _set_custom_kid(plain_handle, custom_kid=LONG_CUSTOM_KID)
    kid_signer = kid_handle.primitive(jwt.JwtPublicKeySign)
    kid_token = kid_signer.sign_and_encode(payload)
    kid_verifier = kid_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    kid_verifier.verify_and_decode(kid_token, validator)

    _, kid_header, _, _ = _jwt_format.split_signed_compact(kid_token)
    self.assertEqual(_json_util.json_loads(kid_header)['kid'],
                     LONG_CUSTOM_KID)

    # The primitive with a custom_kid set accepts tokens without kid header.
    kid_verifier.verify_and_decode(plain_token, validator)

    # The primitive without a custom_kid set ignores the kid header.
    plain_verifier.verify_and_decode(kid_token, validator)

    # Same key with a different custom_kid set.
    other_kid_handle = _set_custom_kid(plain_handle, custom_kid='other kid')
    other_kid_verifier = other_kid_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    # Fails because the kid values do not match.
    with self.assertRaises(tink.TinkError):
      other_kid_verifier.verify_and_decode(kid_token, validator)

    tink_handle = _change_output_prefix_to_tink(kid_handle)
    tink_signer = tink_handle.primitive(jwt.JwtPublicKeySign)
    tink_verifier = tink_handle.public_keyset_handle().primitive(
        jwt.JwtPublicKeyVerify)
    # Having custom_kid set with output prefix TINK is not allowed.
    with self.assertRaises(tink.TinkError):
      tink_signer.sign_and_encode(payload)
    with self.assertRaises(tink.TinkError):
      tink_verifier.verify_and_decode(plain_token, validator)
    with self.assertRaises(tink.TinkError):
      tink_verifier.verify_and_decode(kid_token, validator)

  def test_legacy_template_fails(self):
    """Primitives cannot be created when the primary key is a legacy key."""
    builder = keyset_builder.new_keyset_builder()
    legacy_tmpl = keyset_builder.legacy_template(jwt.jwt_es256_template())
    legacy_key_id = builder.add_new_key(legacy_tmpl)
    builder.set_primary_key(legacy_key_id)
    handle = builder.keyset_handle()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtPublicKeySign)
    with self.assertRaises(tink.TinkError):
      handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)

  def test_legacy_non_primary_key_fails(self):
    """A legacy key in the keyset fails creation even when not primary."""
    builder = keyset_builder.new_keyset_builder()
    legacy_tmpl = keyset_builder.legacy_template(jwt.jwt_es256_template())
    builder.add_new_key(legacy_tmpl)
    primary_key_id = builder.add_new_key(jwt.jwt_es256_template())
    builder.set_primary_key(primary_key_id)
    handle = builder.keyset_handle()
    with self.assertRaises(tink.TinkError):
      handle.primitive(jwt.JwtPublicKeySign)
    with self.assertRaises(tink.TinkError):
      handle.public_keyset_handle().primitive(jwt.JwtPublicKeyVerify)

  def test_jwt_mac_from_keyset_without_primary_fails(self):
    """Building a handle from a keyset with no primary key raises TinkError."""
    # NOTE(review): the method name says "jwt_mac" but the key added here is an
    # ES256 signature key; presumably the name was copied from the MAC wrapper
    # tests. Confirm before renaming.
    builder = keyset_builder.new_keyset_builder()
    builder.add_new_key(jwt.jwt_es256_template())
    with self.assertRaises(tink.TinkError):
      builder.keyset_handle()
311
312
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
315