xref: /aosp_15_r20/external/tink/python/tink/jwt/_jwk_set_converter.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Convert Tink Keyset with JWT keys from and to JWK sets."""
15
16import io
17import json
18import random
19
20from typing import Dict, List, Optional, Union
21
22from tink.proto import jwt_ecdsa_pb2
23from tink.proto import jwt_rsa_ssa_pkcs1_pb2
24from tink.proto import jwt_rsa_ssa_pss_pb2
25from tink.proto import tink_pb2
26import tink
27from tink import cleartext_keyset_handle
28from tink.jwt import _jwt_format
29
# Type URLs of the JWT public key protos that this module converts.
_JWT_ECDSA_PUBLIC_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtEcdsaPublicKey'
)
_JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtRsaSsaPkcs1PublicKey'
)
_JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE = (
    'type.googleapis.com/google.crypto.tink.JwtRsaSsaPssPublicKey'
)

# ECDSA algorithm enum value -> (JWK "alg" name, JWK "crv" name).
_ECDSA_PARAMS = {
    jwt_ecdsa_pb2.ES256: ('ES256', 'P-256'),
    jwt_ecdsa_pb2.ES384: ('ES384', 'P-384'),
    jwt_ecdsa_pb2.ES512: ('ES512', 'P-521'),
}

# JWK "alg" name -> ECDSA algorithm enum value (inverse of _ECDSA_PARAMS).
_ECDSA_NAME_TO_ALGORITHM = {
    names[0]: algorithm for algorithm, names in _ECDSA_PARAMS.items()
}

# RSA SSA PKCS1 algorithm enum value -> JWK "alg" name.
_RSA_SSA_PKCS1_PARAMS = {
    jwt_rsa_ssa_pkcs1_pb2.RS256: 'RS256',
    jwt_rsa_ssa_pkcs1_pb2.RS384: 'RS384',
    jwt_rsa_ssa_pkcs1_pb2.RS512: 'RS512',
}

# JWK "alg" name -> RSA SSA PKCS1 algorithm enum value.
_RSA_SSA_PKCS1_NAME_TO_ALGORITHM = {
    name: algorithm for algorithm, name in _RSA_SSA_PKCS1_PARAMS.items()
}

# RSA SSA PSS algorithm enum value -> JWK "alg" name.
_RSA_SSA_PSS_PARAMS = {
    jwt_rsa_ssa_pss_pb2.PS256: 'PS256',
    jwt_rsa_ssa_pss_pb2.PS384: 'PS384',
    jwt_rsa_ssa_pss_pb2.PS512: 'PS512',
}

# JWK "alg" name -> RSA SSA PSS algorithm enum value.
_RSA_SSA_PSS_NAME_TO_ALGORITHM = {
    name: algorithm for algorithm, name in _RSA_SSA_PSS_PARAMS.items()
}
68
69
def _base64_encode(data: bytes) -> str:
  """Encodes data with _jwt_format.base64_encode and returns it as text."""
  encoded = _jwt_format.base64_encode(data)
  return encoded.decode('utf8')
72
73
def _base64_decode(data: str) -> bytes:
  """Decodes a text value with _jwt_format.base64_decode into raw bytes."""
  encoded = data.encode('utf8')
  return _jwt_format.base64_decode(encoded)
76
77
def from_public_keyset_handle(keyset_handle: tink.KeysetHandle) -> str:
  """Converts a Tink KeysetHandle with JWT keys into a Json Web Key (JWK) set.

  JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.txt.

  Keys that are not enabled are skipped.

  Keys with output prefix type "TINK" will include the encoded key ID as "kid"
  value. Keys with output prefix type "RAW" will not have a "kid" value set
  unless the key carries a custom kid.

  Currently, public keys for algorithms ES256, ES384, ES512, RS256, RS384,
  RS512, PS256, PS384 and PS512 are supported.

  Args:
    keyset_handle: A Tink KeysetHandle that contains JWT Keys.

  Returns:
    A JWK set, which is a JSON encoded string.

  Raises:
    TinkError if the keys are not of the expected type, or if they have an
    output prefix type that is not supported.
  """
  # Maps the type URL of each supported key proto to its JWK converter.
  converters = {
      _JWT_ECDSA_PUBLIC_KEY_TYPE: _convert_jwt_ecdsa_key,
      _JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE: _convert_jwt_rsa_ssa_pkcs1_key,
      _JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE: _convert_jwt_rsa_ssa_pss_key,
  }

  # Serialize the handle without secrets and re-parse it to reach the
  # underlying keyset proto.
  buffer = io.BytesIO()
  keyset_handle.write_no_secret(tink.BinaryKeysetWriter(buffer))
  keyset = tink_pb2.Keyset.FromString(buffer.getvalue())

  jwks = []
  for key in keyset.key:
    if key.status != tink_pb2.ENABLED:
      continue
    if key.key_data.key_material_type != tink_pb2.KeyData.ASYMMETRIC_PUBLIC:
      raise tink.TinkError('wrong key material type')
    if key.output_prefix_type not in (tink_pb2.RAW, tink_pb2.TINK):
      raise tink.TinkError('unsupported output prefix type')
    type_url = key.key_data.type_url
    convert = converters.get(type_url)
    if convert is None:
      raise tink.TinkError('unknown key type: %s' % type_url)
    jwks.append(convert(key))
  # Compact separators keep the JSON output free of whitespace.
  return json.dumps({'keys': jwks}, separators=(',', ':'))
123
124
# Deprecated. Use from_public_keyset_handle instead.
def from_keyset_handle(keyset_handle: tink.KeysetHandle,
                       key_access: Optional[tink.KeyAccess] = None) -> str:
  """Deprecated wrapper around from_public_keyset_handle.

  Args:
    keyset_handle: A Tink KeysetHandle that contains JWT Keys.
    key_access: Ignored.

  Returns:
    The result of from_public_keyset_handle(keyset_handle).
  """
  del key_access  # Unused.
  return from_public_keyset_handle(keyset_handle)
130
131
def _convert_jwt_ecdsa_key(
    key: tink_pb2.Keyset.Key) -> Dict[str, Union[str, List[str]]]:
  """Builds the JWK dictionary for a keyset key holding a JwtEcdsaPublicKey.

  Args:
    key: A keyset key whose key_data.value is a serialized JwtEcdsaPublicKey.

  Returns:
    A dictionary with the JWK members 'kty', 'crv', 'x', 'y', 'use', 'alg'
    and 'key_ops', plus 'kid' when one is available.

  Raises:
    TinkError if the algorithm of the key is not in _ECDSA_PARAMS.
  """
  proto = jwt_ecdsa_pb2.JwtEcdsaPublicKey.FromString(key.key_data.value)
  params = _ECDSA_PARAMS.get(proto.algorithm)
  if params is None:
    raise tink.TinkError('unknown ecdsa algorithm')
  alg_name, curve = params
  jwk = {
      'kty': 'EC',
      'crv': curve,
      'x': _base64_encode(proto.x),
      'y': _base64_encode(proto.y),
      'use': 'sig',
      'alg': alg_name,
      'key_ops': ['verify'],
  }
  # A kid derived from the key ID takes precedence over a custom kid.
  derived_kid = _jwt_format.get_kid(key.key_id, key.output_prefix_type)
  if derived_kid:
    jwk['kid'] = derived_kid
    return jwk
  if proto.HasField('custom_kid'):
    jwk['kid'] = proto.custom_kid.value
  return jwk
155
156
def _convert_jwt_rsa_ssa_pkcs1_key(
    key: tink_pb2.Keyset.Key) -> Dict[str, Union[str, List[str]]]:
  """Builds the JWK dictionary for a key holding a JwtRsaSsaPkcs1PublicKey.

  Args:
    key: A keyset key whose key_data.value is a serialized
      JwtRsaSsaPkcs1PublicKey.

  Returns:
    A dictionary with the JWK members 'kty', 'n', 'e', 'use', 'alg' and
    'key_ops', plus 'kid' when one is available.

  Raises:
    TinkError if the algorithm of the key is not in _RSA_SSA_PKCS1_PARAMS.
  """
  proto = jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PublicKey.FromString(
      key.key_data.value)
  alg_name = _RSA_SSA_PKCS1_PARAMS.get(proto.algorithm)
  if alg_name is None:
    raise tink.TinkError('unknown RSA SSA PKCS1 algorithm')
  jwk = {
      'kty': 'RSA',
      'n': _base64_encode(proto.n),
      'e': _base64_encode(proto.e),
      'use': 'sig',
      'alg': alg_name,
      'key_ops': ['verify'],
  }
  # A kid derived from the key ID takes precedence over a custom kid.
  derived_kid = _jwt_format.get_kid(key.key_id, key.output_prefix_type)
  if derived_kid:
    jwk['kid'] = derived_kid
    return jwk
  if proto.HasField('custom_kid'):
    jwk['kid'] = proto.custom_kid.value
  return jwk
179
180
def _convert_jwt_rsa_ssa_pss_key(
    key: tink_pb2.Keyset.Key) -> Dict[str, Union[str, List[str]]]:
  """Builds the JWK dictionary for a key holding a JwtRsaSsaPssPublicKey.

  Args:
    key: A keyset key whose key_data.value is a serialized
      JwtRsaSsaPssPublicKey.

  Returns:
    A dictionary with the JWK members 'kty', 'n', 'e', 'use', 'alg' and
    'key_ops', plus 'kid' when one is available.

  Raises:
    TinkError if the algorithm of the key is not in _RSA_SSA_PSS_PARAMS.
  """
  proto = jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPublicKey.FromString(
      key.key_data.value)
  alg_name = _RSA_SSA_PSS_PARAMS.get(proto.algorithm)
  if alg_name is None:
    raise tink.TinkError('unknown RSA SSA PSS algorithm')
  jwk = {
      'kty': 'RSA',
      'n': _base64_encode(proto.n),
      'e': _base64_encode(proto.e),
      'use': 'sig',
      'alg': alg_name,
      'key_ops': ['verify'],
  }
  # A kid derived from the key ID takes precedence over a custom kid.
  derived_kid = _jwt_format.get_kid(key.key_id, key.output_prefix_type)
  if derived_kid:
    jwk['kid'] = derived_kid
    return jwk
  if proto.HasField('custom_kid'):
    jwk['kid'] = proto.custom_kid.value
  return jwk
203
204
def _generate_unused_key_id(keyset: tink_pb2.Keyset) -> int:
  """Returns a random key ID that no key in keyset currently uses.

  Args:
    keyset: The keyset whose existing key IDs must be avoided.

  Returns:
    A random integer in [1, 2147483647] that differs from the key_id of every
    key in keyset.
  """
  # The keyset is not modified while searching, so the set of used IDs is
  # built once instead of on every retry.
  used_ids = {key.key_id for key in keyset.key}
  while True:
    key_id = random.randint(1, 2147483647)
    if key_id not in used_ids:
      return key_id
210
211
def to_public_keyset_handle(jwk_set: str) -> tink.KeysetHandle:
  """Converts a Json Web Key (JWK) set into a Tink KeysetHandle with JWT keys.

  JWK is defined in https://www.rfc-editor.org/rfc/rfc7517.txt.

  All keys are converted into Tink keys with output prefix type "RAW".

  Currently, public keys for algorithms ES256, ES384, ES512, RS256, RS384,
  RS512, PS256, PS384 and PS512 are supported.

  Args:
    jwk_set: A JWK set, which is a JSON encoded string.

  Returns:
    A tink.KeysetHandle.

  Raises:
    TinkError if the JWK set is malformed or a key cannot be converted.
  """
  try:
    keys_dict = json.loads(jwk_set)
  except json.decoder.JSONDecodeError as e:
    raise tink.TinkError('error parsing JWK set: %s' % e.msg) from e
  # Valid JSON need not be an object (it could be a string, number or list),
  # so the structure is checked explicitly. Otherwise malformed input would
  # surface as TypeError/AttributeError instead of TinkError.
  if not isinstance(keys_dict, dict) or 'keys' not in keys_dict:
    raise tink.TinkError('invalid JWK set: keys not found')
  jwks = keys_dict['keys']
  if not isinstance(jwks, list):
    raise tink.TinkError('invalid JWK set: keys is not a list')
  proto_keyset = tink_pb2.Keyset()
  for key in jwks:
    if not isinstance(key, dict):
      raise tink.TinkError('invalid JWK: not a JSON object')
    if 'alg' not in key:
      raise tink.TinkError('invalid JWK: alg not found')
    alg = key['alg']
    if not isinstance(alg, str):
      raise tink.TinkError('invalid JWK: alg is not a string')
    if alg.startswith('ES'):
      proto_key = _convert_to_ecdsa_key(key)
    elif alg.startswith('RS'):
      proto_key = _convert_to_rsa_ssa_pkcs1_key(key)
    elif alg.startswith('PS'):
      proto_key = _convert_to_rsa_ssa_pss_key(key)
    else:
      raise tink.TinkError('unknown alg')
    new_id = _generate_unused_key_id(proto_keyset)
    proto_key.key_id = new_id
    proto_keyset.key.append(proto_key)
    # JWK sets do not really have a primary key (see RFC 7517, Section 5.1).
    # To verify signature, it also does not matter which key is primary. We
    # simply set it to the last key.
    proto_keyset.primary_key_id = new_id
  return cleartext_keyset_handle.from_keyset(proto_keyset)
258
259
# Deprecated. Use to_public_keyset_handle instead.
def to_keyset_handle(
    jwk_set: str,
    key_access: Optional[tink.KeyAccess] = None) -> tink.KeysetHandle:
  """Deprecated wrapper around to_public_keyset_handle.

  Args:
    jwk_set: A JWK set, which is a JSON encoded string.
    key_access: Ignored.

  Returns:
    The result of to_public_keyset_handle(jwk_set).
  """
  del key_access  # Unused.
  return to_public_keyset_handle(jwk_set)
266
267
268def _validate_use_and_key_ops(key: Dict[str, Union[str, List[str]]]):
269  """Checks that 'key_ops' and 'use' have the right values if present."""
270  if 'key_ops' in key:
271    key_ops = key['key_ops']
272    if len(key_ops) != 1 or key_ops[0] != 'verify':
273      raise tink.TinkError('invalid key_ops')
274  if 'use' in key and key['use'] != 'sig':
275    raise tink.TinkError('invalid use')
276
277
def _convert_to_ecdsa_key(
    key: Dict[str, Union[str, List[str]]]) -> tink_pb2.Keyset.Key:
  """Converts a EC Json Web Key (JWK) into a tink_pb2.Keyset.Key.

  Args:
    key: The parsed JWK of a public EC key.

  Returns:
    An enabled tink_pb2.Keyset.Key with output prefix type RAW that wraps a
    serialized JwtEcdsaPublicKey. The key_id is not set here.

  Raises:
    TinkError if the algorithm is unknown, if 'kty', 'crv', 'use' or 'key_ops'
    have unexpected values, if the JWK contains the private member 'd', or if
    'x', 'y' or 'kid' are missing or malformed.
  """
  alg = key.get('alg', None)
  # Only strings can name an algorithm; anything else is treated as unknown.
  algorithm = (
      _ECDSA_NAME_TO_ALGORITHM.get(alg, None) if isinstance(alg, str) else None)
  if not algorithm:
    raise tink.TinkError('unknown ECDSA algorithm')
  if key.get('kty', None) != 'EC':
    raise tink.TinkError('invalid kty')
  _, crv = _ECDSA_PARAMS[algorithm]
  if key.get('crv', None) != crv:
    raise tink.TinkError('invalid crv')
  _validate_use_and_key_ops(key)
  if 'd' in key:
    raise tink.TinkError('cannot convert private ECDSA key')
  # Check the coordinates up front so that malformed input is reported as
  # TinkError rather than as KeyError or AttributeError.
  for name in ('x', 'y'):
    if not isinstance(key.get(name, None), str):
      raise tink.TinkError('invalid JWK: %s missing or not a string' % name)
  if 'kid' in key and not isinstance(key['kid'], str):
    raise tink.TinkError('invalid JWK: kid is not a string')
  ecdsa_public_key = jwt_ecdsa_pb2.JwtEcdsaPublicKey()
  ecdsa_public_key.algorithm = algorithm
  ecdsa_public_key.x = _base64_decode(key['x'])
  ecdsa_public_key.y = _base64_decode(key['y'])
  if 'kid' in key:
    ecdsa_public_key.custom_kid.value = key['kid']
  proto_key = tink_pb2.Keyset.Key()
  proto_key.key_data.type_url = _JWT_ECDSA_PUBLIC_KEY_TYPE
  proto_key.key_data.value = ecdsa_public_key.SerializeToString()
  proto_key.key_data.key_material_type = tink_pb2.KeyData.ASYMMETRIC_PUBLIC
  proto_key.output_prefix_type = tink_pb2.RAW
  proto_key.status = tink_pb2.ENABLED
  return proto_key
305
306
def _convert_to_rsa_ssa_pkcs1_key(
    key: Dict[str, Union[str, List[str]]]) -> tink_pb2.Keyset.Key:
  """Converts an RSA Json Web Key (JWK) into a tink_pb2.Keyset.Key.

  Args:
    key: The parsed JWK of a public RSA key with an RS* algorithm.

  Returns:
    An enabled tink_pb2.Keyset.Key with output prefix type RAW that wraps a
    serialized JwtRsaSsaPkcs1PublicKey. The key_id is not set here.

  Raises:
    TinkError if the algorithm is unknown, if 'kty', 'use' or 'key_ops' have
    unexpected values, if the JWK contains private key members, or if 'n',
    'e' or 'kid' are missing or malformed.
  """
  alg = key.get('alg', None)
  # Only strings can name an algorithm; anything else is treated as unknown.
  algorithm = (
      _RSA_SSA_PKCS1_NAME_TO_ALGORITHM.get(alg, None)
      if isinstance(alg, str) else None)
  if not algorithm:
    raise tink.TinkError('unknown RSA SSA PKCS1 algorithm')
  if key.get('kty', None) != 'RSA':
    raise tink.TinkError('invalid kty')
  _validate_use_and_key_ops(key)
  if any(name in key for name in ('p', 'q', 'dp', 'dq', 'd', 'qi')):
    raise tink.TinkError('importing RSA private keys is not implemented')
  # Check modulus and exponent up front so that malformed input is reported
  # as TinkError rather than as KeyError or AttributeError.
  for name in ('n', 'e'):
    if not isinstance(key.get(name, None), str):
      raise tink.TinkError('invalid JWK: %s missing or not a string' % name)
  if 'kid' in key and not isinstance(key['kid'], str):
    raise tink.TinkError('invalid JWK: kid is not a string')
  public_key = jwt_rsa_ssa_pkcs1_pb2.JwtRsaSsaPkcs1PublicKey()
  public_key.algorithm = algorithm
  public_key.n = _base64_decode(key['n'])
  public_key.e = _base64_decode(key['e'])
  if 'kid' in key:
    public_key.custom_kid.value = key['kid']
  proto_key = tink_pb2.Keyset.Key()
  proto_key.key_data.type_url = _JWT_RSA_SSA_PKCS1_PUBLIC_KEY_TYPE
  proto_key.key_data.value = public_key.SerializeToString()
  proto_key.key_data.key_material_type = tink_pb2.KeyData.ASYMMETRIC_PUBLIC
  proto_key.output_prefix_type = tink_pb2.RAW
  proto_key.status = tink_pb2.ENABLED
  return proto_key
332
333
def _convert_to_rsa_ssa_pss_key(
    key: Dict[str, Union[str, List[str]]]) -> tink_pb2.Keyset.Key:
  """Converts an RSA Json Web Key (JWK) into a tink_pb2.Keyset.Key.

  Args:
    key: The parsed JWK of a public RSA key with a PS* algorithm.

  Returns:
    An enabled tink_pb2.Keyset.Key with output prefix type RAW that wraps a
    serialized JwtRsaSsaPssPublicKey. The key_id is not set here.

  Raises:
    TinkError if the algorithm is unknown, if 'kty', 'use' or 'key_ops' have
    unexpected values, if the JWK contains private key members, or if 'n',
    'e' or 'kid' are missing or malformed.
  """
  alg = key.get('alg', None)
  # Only strings can name an algorithm; anything else is treated as unknown.
  algorithm = (
      _RSA_SSA_PSS_NAME_TO_ALGORITHM.get(alg, None)
      if isinstance(alg, str) else None)
  if not algorithm:
    raise tink.TinkError('unknown RSA SSA PSS algorithm')
  if key.get('kty', None) != 'RSA':
    raise tink.TinkError('invalid kty')
  _validate_use_and_key_ops(key)
  if any(name in key for name in ('p', 'q', 'dp', 'dq', 'd', 'qi')):
    raise tink.TinkError('importing RSA private keys is not implemented')
  # Check modulus and exponent up front so that malformed input is reported
  # as TinkError rather than as KeyError or AttributeError.
  for name in ('n', 'e'):
    if not isinstance(key.get(name, None), str):
      raise tink.TinkError('invalid JWK: %s missing or not a string' % name)
  if 'kid' in key and not isinstance(key['kid'], str):
    raise tink.TinkError('invalid JWK: kid is not a string')
  public_key = jwt_rsa_ssa_pss_pb2.JwtRsaSsaPssPublicKey()
  public_key.algorithm = algorithm
  public_key.n = _base64_decode(key['n'])
  public_key.e = _base64_decode(key['e'])
  if 'kid' in key:
    public_key.custom_kid.value = key['kid']
  proto_key = tink_pb2.Keyset.Key()
  proto_key.key_data.type_url = _JWT_RSA_SSA_PSS_PUBLIC_KEY_TYPE
  proto_key.key_data.value = public_key.SerializeToString()
  proto_key.key_data.key_material_type = tink_pb2.KeyData.ASYMMETRIC_PUBLIC
  proto_key.output_prefix_type = tink_pb2.RAW
  proto_key.status = tink_pb2.ENABLED
  return proto_key
359