xref: /aosp_15_r20/external/tink/testing/cross_language/jwt_kid_test.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Cross-language tests for the "kid" header set by JWT primitives."""
15
16import base64
17import json
18from typing import Optional
19
20from absl.testing import absltest
21from absl.testing import parameterized
22import tink
23from tink import jwt
24
25from tink.proto import jwt_ecdsa_pb2
26from tink.proto import jwt_hmac_pb2
27from tink.proto import jwt_rsa_ssa_pkcs1_pb2
28from tink.proto import jwt_rsa_ssa_pss_pb2
29from tink.proto import tink_pb2
30from util import testing_servers
31from util import utilities
32
# Languages listed for the 'jwt' primitive in testing_servers; the JWT MAC
# tests in this module run against every language in this list.
SUPPORTED_LANGUAGES = testing_servers.SUPPORTED_LANGUAGES_BY_PRIMITIVE['jwt']
34
35
def setUpModule():
  """Runs once before the tests in this module.

  Calls Tink's JWT MAC and JWT signature registration functions, in that
  order, and then starts the testing servers for 'jwt'.
  """
  for register in (jwt.register_jwt_mac, jwt.register_jwt_signature):
    register()
  testing_servers.start('jwt')
40
41
def tearDownModule():
  """Runs once after the tests in this module; stops the testing servers."""
  testing_servers.stop()
44
45
def base64_decode(encoded_data: bytes) -> bytes:
  """Decodes URL-safe base64 data that may lack its trailing '=' padding.

  Three '=' characters are always appended before decoding, which is enough
  for any amount of missing padding; the decoder ignores the excess.

  Args:
    encoded_data: URL-safe base64 bytes, with or without padding.

  Returns:
    The decoded bytes.
  """
  return base64.urlsafe_b64decode(encoded_data + b'===')
49
50
def decode_kid(compact: str) -> Optional[str]:
  """Extracts the "kid" header value from a compact-serialized token.

  Args:
    compact: A token made of exactly three '.'-separated parts, the first of
      which is the base64url-encoded JSON header.

  Returns:
    The value stored under "kid" in the decoded header, or None if the header
    has no such entry.

  Raises:
    ValueError: If the token does not split into exactly three parts.
  """
  segments = compact.encode('utf8').split(b'.')
  # Unpacking enforces the three-part shape; only the header part is used.
  header_segment, _, _ = segments
  return json.loads(base64_decode(header_segment)).get('kid')
56
57
def generate_jwt_mac_keyset_with_custom_kid(
    template_name: str, custom_kid: str) -> tink_pb2.Keyset:
  """Generates a JWT MAC keyset whose first key has custom_kid set.

  A new keyset handle is created from the named template, the serialized key
  inside the first keyset entry is parsed, its custom_kid is set, and the key
  is written back in serialized form.

  Args:
    template_name: Key of utilities.KEY_TEMPLATE; must start with 'JWT_HS256'.
    custom_kid: Value stored in the key's custom_kid field.

  Returns:
    The modified keyset proto held by the new keyset handle.

  Raises:
    ValueError: If template_name does not start with 'JWT_HS256'.
  """
  handle = tink.new_keyset_handle(utilities.KEY_TEMPLATE[template_name])
  # The handle's private keyset proto is edited in place.
  keyset = handle._keyset
  serialized_key = keyset.key[0].key_data.value
  if not template_name.startswith('JWT_HS256'):
    raise ValueError('unknown alg')
  hmac_key = jwt_hmac_pb2.JwtHmacKey.FromString(serialized_key)
  hmac_key.custom_kid.value = custom_kid
  keyset.key[0].key_data.value = hmac_key.SerializeToString()
  return keyset
72
73
def generate_jwt_signature_keyset_with_custom_kid(
    template_name: str, custom_kid: str) -> tink_pb2.Keyset:
  """Generates a JWT signature keyset whose first key has custom_kid set.

  A new keyset handle is created from the named template, the serialized
  private key inside the first keyset entry is parsed, custom_kid is set on
  its embedded public key, and the key is written back in serialized form.

  Args:
    template_name: Key of utilities.KEY_TEMPLATE; must start with 'JWT_ES256',
      'JWT_RS256' or 'JWT_PS256'.
    custom_kid: Value stored in the public key's custom_kid field.

  Returns:
    The modified keyset proto held by the new keyset handle.

  Raises:
    ValueError: If template_name starts with none of the known prefixes.
  """
  handle = tink.new_keyset_handle(utilities.KEY_TEMPLATE[template_name])
  # The handle's private keyset proto is edited in place.
  keyset = handle._keyset
  serialized_key = keyset.key[0].key_data.value
  # Template-name prefix -> proto class used to parse the private key.
  private_key_types = (
      ('JWT_ES256', jwt_ecdsa_pb2.JwtEcdsaPrivateKey),
      ('JWT_RS256', jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PrivateKey),
      ('JWT_PS256', jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPrivateKey),
  )
  private_key_type = next(
      (proto_type for prefix, proto_type in private_key_types
       if template_name.startswith(prefix)),
      None)
  if private_key_type is None:
    raise ValueError('unknown template name')
  private_key = private_key_type.FromString(serialized_key)
  private_key.public_key.custom_kid.value = custom_kid
  keyset.key[0].key_data.value = private_key.SerializeToString()
  return keyset
99
100
class JwtKidTest(parameterized.TestCase):
  """Tests that all JWT primitives consistently add a "kid" header to tokens.

  Each test creates one keyset, obtains a remote primitive for it from every
  language under test, produces a compact token, and inspects the "kid" entry
  of the token's header. Assertion messages name the language so a failure
  can be attributed to a specific implementation.
  """

  @parameterized.parameters(['JWT_HS256'])
  def test_jwt_mac_sets_kid_for_tink_templates(self, template_name):
    """MAC tokens from non-RAW templates must carry a "kid" header."""
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      jwt_mac = testing_servers.remote_primitive(lang, keyset, jwt.JwtMac)
      compact = jwt_mac.compute_mac_and_encode(raw_jwt)
      self.assertIsNotNone(
          decode_kid(compact), msg='%s did not set a kid header' % lang)

  @parameterized.parameters(['JWT_HS256_RAW'])
  def test_jwt_mac_does_not_sets_kid_for_raw_templates(self, template_name):
    """MAC tokens from RAW templates must not carry a "kid" header."""
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      jwt_mac = testing_servers.remote_primitive(lang, keyset, jwt.JwtMac)
      compact = jwt_mac.compute_mac_and_encode(raw_jwt)
      self.assertIsNone(
          decode_kid(compact), msg='%s set a kid header unexpectedly' % lang)

  @parameterized.parameters(
      ['JWT_ES256', 'JWT_RS256_2048_F4', 'JWT_PS256_2048_F4'])
  def test_jwt_public_key_sign_sets_kid_for_tink_templates(self, template_name):
    """Signed tokens from non-RAW templates must carry a "kid" header."""
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      jwt_sign = testing_servers.remote_primitive(lang, keyset,
                                                  jwt.JwtPublicKeySign)
      compact = jwt_sign.sign_and_encode(raw_jwt)
      self.assertIsNotNone(
          decode_kid(compact), msg='%s did not set a kid header' % lang)

  @parameterized.parameters(
      ['JWT_ES256_RAW', 'JWT_RS256_2048_F4_RAW', 'JWT_PS256_2048_F4_RAW'])
  def test_jwt_public_key_sign_does_not_sets_kid_for_raw_templates(
      self, template_name):
    """Signed tokens from RAW templates must not carry a "kid" header."""
    key_template = utilities.KEY_TEMPLATE[template_name]
    keyset = testing_servers.new_keyset('cc', key_template)
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      jwt_sign = testing_servers.remote_primitive(lang, keyset,
                                                  jwt.JwtPublicKeySign)
      compact = jwt_sign.sign_and_encode(raw_jwt)
      self.assertIsNone(
          decode_kid(compact), msg='%s set a kid header unexpectedly' % lang)

  @parameterized.parameters(['JWT_HS256_RAW'])
  def test_jwt_mac_sets_custom_kid_for_raw_keys(self, template_name):
    """A custom_kid on a RAW MAC key must be emitted as the "kid" header."""
    keyset = generate_jwt_mac_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      jwt_mac = testing_servers.remote_primitive(lang,
                                                 keyset.SerializeToString(),
                                                 jwt.JwtMac)
      compact = jwt_mac.compute_mac_and_encode(raw_jwt)
      self.assertEqual(
          decode_kid(compact), 'my kid',
          msg='%s did not set the custom kid' % lang)

  @parameterized.parameters(['JWT_HS256'])
  def test_jwt_mac_fails_for_tink_keys_with_custom_kid(self, template_name):
    """A non-RAW MAC key with custom_kid set must be rejected."""
    keyset = generate_jwt_mac_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    for lang in SUPPORTED_LANGUAGES:
      # Either creating the primitive or using it may raise, so both calls
      # sit inside the assertRaises context.
      with self.assertRaises(
          tink.TinkError,
          msg=('%s supports JWT mac keys with TINK output prefix type '
               'and custom_kid set unexpectedly') % lang):
        jwt_mac = testing_servers.remote_primitive(lang,
                                                   keyset.SerializeToString(),
                                                   jwt.JwtMac)
        jwt_mac.compute_mac_and_encode(raw_jwt)

  @parameterized.parameters(
      ['JWT_ES256_RAW', 'JWT_RS256_2048_F4_RAW', 'JWT_PS256_2048_F4_RAW'])
  def test_jwt_public_key_sign_sets_custom_kid_for_raw_keys(
      self, template_name):
    """A custom_kid on a RAW signing key must be emitted as the "kid" header."""
    keyset = generate_jwt_signature_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      jwt_sign = testing_servers.remote_primitive(lang,
                                                  keyset.SerializeToString(),
                                                  jwt.JwtPublicKeySign)
      compact = jwt_sign.sign_and_encode(raw_jwt)
      self.assertEqual(
          decode_kid(compact), 'my kid',
          msg='%s did not set the custom kid' % lang)

  @parameterized.parameters(
      ['JWT_ES256', 'JWT_RS256_2048_F4', 'JWT_PS256_2048_F4'])
  def test_jwt_public_key_sign_fails_for_tink_keys_with_custom_kid(
      self, template_name):
    """A non-RAW signing key with custom_kid set must be rejected."""
    keyset = generate_jwt_signature_keyset_with_custom_kid(
        template_name=template_name, custom_kid='my kid')
    raw_jwt = jwt.new_raw_jwt(without_expiration=True)
    # Restrict the loop to languages listed for this template, like the other
    # signature tests. A language without the template could raise TinkError
    # for an unrelated reason and make the assertion pass vacuously.
    supported_langs = utilities.SUPPORTED_LANGUAGES_BY_TEMPLATE_NAME[
        template_name]
    for lang in supported_langs:
      # Either creating the primitive or using it may raise, so both calls
      # sit inside the assertRaises context.
      with self.assertRaises(
          tink.TinkError,
          msg=('%s supports JWT signature keys with TINK output prefix type '
               'and custom_kid set unexpectedly') % lang):
        jwt_sign = testing_servers.remote_primitive(lang,
                                                    keyset.SerializeToString(),
                                                    jwt.JwtPublicKeySign)
        jwt_sign.sign_and_encode(raw_jwt)
212
213
# Run the tests through the absl test runner when executed as a script.
if __name__ == '__main__':
  absltest.main()
216