xref: /aosp_15_r20/external/tink/python/tink/jwt/_jwt_format.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Functions that help to serialize and deserialize from/to the JWT format."""
15
16import base64
17import binascii
18import struct
19from typing import Any, Optional, Tuple
20
21from tink.proto import tink_pb2
22from tink.jwt import _json_util
23from tink.jwt import _jwt_error
24from tink.jwt import _raw_jwt
25
26_VALID_ALGORITHMS = frozenset({
27    'HS256', 'HS384', 'HS512', 'ES256', 'ES384', 'ES512', 'RS256', 'RS384',
28    'RS384', 'RS512', 'PS256', 'PS384', 'PS512'
29})
30
31
def base64_encode(data: bytes) -> bytes:
  """Encodes data as URL-safe base64 and drops the trailing '=' padding.

  Args:
    data: The raw bytes to encode.

  Returns:
    The URL-safe base64 encoding of data with all trailing '=' removed.
  """
  padded = base64.urlsafe_b64encode(data)
  return padded.rstrip(b'=')
35
36
37def _is_valid_urlsafe_base64_char(c: int) -> bool:
38  if c >= ord('a') and c <= ord('z'):
39    return True
40  if c >= ord('A') and c <= ord('Z'):
41    return True
42  if c >= ord('0') and c <= ord('9'):
43    return True
44  if c == ord('-') or c == ord('_'):
45    return True
46  return False
47
48
def base64_decode(encoded_data: bytes) -> bytes:
  """Decodes unpadded URL-safe base64 data, rejecting foreign characters.

  Args:
    encoded_data: URL-safe base64 data without '=' padding.

  Returns:
    The decoded bytes.

  Raises:
    _jwt_error.JwtInvalidError: if encoded_data contains a character that is
      not accepted by _is_valid_urlsafe_base64_char, or if the decoder rejects
      the input.
  """
  # The standard library decoder silently skips characters outside of the
  # base64 alphabet, so the alphabet is enforced here before decoding.
  if not all(_is_valid_urlsafe_base64_char(byte) for byte in encoded_data):
    raise _jwt_error.JwtInvalidError('invalid base64 encoding')
  # The decoder insists on padding but tolerates a surplus of it, so the
  # maximum amount that could ever be needed is always appended.
  with_padding = encoded_data + b'==='
  try:
    decoded = base64.urlsafe_b64decode(with_padding)
  except binascii.Error:
    # Happens when len(encoded_data) is 4*i + 1 for some i.
    raise _jwt_error.JwtInvalidError('invalid base64 encoding')
  return decoded
63
64
def _validate_algorithm(algorithm: str) -> None:
  """Raises JwtInvalidError unless algorithm is in _VALID_ALGORITHMS."""
  if algorithm in _VALID_ALGORITHMS:
    return
  raise _jwt_error.JwtInvalidError('Invalid algorithm %s' % algorithm)
68
69
def encode_header(json_header: str) -> bytes:
  """Encodes a JSON header string into its compact form.

  Args:
    json_header: The header as a JSON string.

  Returns:
    The unpadded URL-safe base64 encoding of the UTF-8 encoded header.

  Raises:
    _jwt_error.JwtInvalidError: if json_header cannot be encoded as UTF-8.
  """
  try:
    utf8_header = json_header.encode('utf8')
  except UnicodeEncodeError:
    raise _jwt_error.JwtInvalidError('invalid token')
  return base64_encode(utf8_header)
75
76
def decode_header(encoded_header: bytes) -> str:
  """Decodes a header from its compact form into a JSON string.

  Args:
    encoded_header: The unpadded URL-safe base64 encoded header.

  Returns:
    The decoded header as a str.

  Raises:
    _jwt_error.JwtInvalidError: if the base64 decoding fails or the decoded
      bytes are not valid UTF-8.
  """
  utf8_header = base64_decode(encoded_header)
  try:
    return utf8_header.decode('utf8')
  except UnicodeDecodeError:
    raise _jwt_error.JwtInvalidError('invalid token')
82
83
def encode_payload(json_payload: str) -> bytes:
  """Encodes a JSON payload string into its compact form.

  Args:
    json_payload: The payload as a JSON string.

  Returns:
    The unpadded URL-safe base64 encoding of the UTF-8 encoded payload.

  Raises:
    _jwt_error.JwtInvalidError: if json_payload cannot be encoded as UTF-8.
  """
  try:
    utf8_payload = json_payload.encode('utf8')
  except UnicodeEncodeError:
    raise _jwt_error.JwtInvalidError('invalid token')
  return base64_encode(utf8_payload)
90
91
def decode_payload(encoded_payload: bytes) -> str:
  """Decodes a payload from its compact form into a JSON string.

  Args:
    encoded_payload: The unpadded URL-safe base64 encoded payload.

  Returns:
    The decoded payload as a str.

  Raises:
    _jwt_error.JwtInvalidError: if the base64 decoding fails or the decoded
      bytes are not valid UTF-8.
  """
  utf8_payload = base64_decode(encoded_payload)
  try:
    return utf8_payload.decode('utf8')
  except UnicodeDecodeError:
    raise _jwt_error.JwtInvalidError('invalid token')
98
99
def encode_signature(signature: bytes) -> bytes:
  """Returns the unpadded URL-safe base64 encoding of the signature bytes."""
  encoded_signature = base64_encode(signature)
  return encoded_signature
103
104
def decode_signature(encoded_signature: bytes) -> bytes:
  """Returns the raw signature bytes decoded from unpadded URL-safe base64.

  Any error raised by base64_decode is propagated unchanged.
  """
  signature = base64_decode(encoded_signature)
  return signature
108
109
def create_header(algorithm: str, type_header: Optional[str],
                  kid: Optional[str]) -> bytes:
  """Builds the encoded JWT header.

  Args:
    algorithm: Value of the 'alg' entry; must be a valid algorithm name.
    type_header: Value of the 'typ' entry; left out when empty or None.
    kid: Value of the 'kid' entry; left out when empty or None.

  Returns:
    The header in compact (base64 encoded) form.

  Raises:
    _jwt_error.JwtInvalidError: if algorithm is not a valid algorithm name.
  """
  _validate_algorithm(algorithm)
  # Entries are collected in the order 'kid', 'alg', 'typ'; the resulting dict
  # keeps that insertion order.
  entries = []
  if kid:
    entries.append(('kid', kid))
  entries.append(('alg', algorithm))
  if type_header:
    entries.append(('typ', type_header))
  json_header = _json_util.json_dumps(dict(entries))
  return encode_header(json_header)
120
121
def get_kid(key_id: int, prefix: tink_pb2.OutputPrefixType) -> Optional[str]:
  """Returns the encoded key_id, or None.

  Args:
    key_id: The numeric key ID; for TINK keys it must be in [0, 2**32 - 1].
    prefix: The output prefix type of the key.

  Returns:
    None if prefix is RAW. If prefix is TINK, the unpadded URL-safe base64
    encoding of key_id packed as a 4-byte big-endian unsigned integer.

  Raises:
    _jwt_error.JwtInvalidError: if key_id is out of range for a TINK key, or
      if prefix is neither RAW nor TINK.
  """
  if prefix == tink_pb2.RAW:
    return None
  if prefix == tink_pb2.TINK:
    # The '>L' format holds at most 2**32 - 1. The upper bound must therefore
    # be exclusive; otherwise key_id == 2**32 would reach struct.pack and fail
    # with struct.error instead of JwtInvalidError.
    if key_id < 0 or key_id >= 2**32:
      raise _jwt_error.JwtInvalidError('invalid key_id')
    return base64_encode(struct.pack('>L', key_id)).decode('utf8')
  raise _jwt_error.JwtInvalidError('unexpected output prefix type')
131
132
def split_signed_compact(signed_compact: str) -> Tuple[bytes, str, str, bytes]:
  """Breaks a signed compact JWT into its decoded components.

  Args:
    signed_compact: A signed compact JWT.

  Returns:
    A (unsigned_compact, json_header, json_payload, signature_or_mac) tuple,
    where unsigned_compact is everything before the last '.' as bytes.

  Raises:
    _jwt_error.JwtInvalidError: if signed_compact is not a str, cannot be
      encoded as UTF-8, does not consist of exactly three '.'-separated parts,
      or if any part fails to decode.
  """
  if not isinstance(signed_compact, str):
    raise _jwt_error.JwtInvalidError('invalid token: not a str')
  try:
    token_bytes = signed_compact.encode('utf8')
  except UnicodeEncodeError:
    raise _jwt_error.JwtInvalidError('invalid token')

  # Cut at the last '.': the left side is the unsigned compact, the right side
  # is the encoded signature. An empty separator means there was no '.'.
  unsigned_compact, last_dot, encoded_signature = token_bytes.rpartition(b'.')
  if not last_dot:
    raise _jwt_error.JwtInvalidError('invalid token')
  signature_or_mac = decode_signature(encoded_signature)

  # The unsigned compact must be exactly 'header.payload'.
  segments = unsigned_compact.split(b'.')
  if len(segments) != 2:
    raise _jwt_error.JwtInvalidError('invalid token')
  encoded_header, encoded_payload = segments

  # The header is decoded before the payload.
  return (unsigned_compact, decode_header(encoded_header),
          decode_payload(encoded_payload), signature_or_mac)
163
164
165def _validate_kid_header(header: Any, kid: str) -> None:
166  if header['kid'] != kid:
167    raise _jwt_error.JwtInvalidError('invalid kid header')
168
169
def validate_header(header: Any,
                    algorithm: str,
                    tink_kid: Optional[str] = None,
                    custom_kid: Optional[str] = None) -> None:
  """Validates the values of a parsed header.

  Args:
    header: The parsed header; it is accessed like a dict.
    algorithm: The expected algorithm; must be a valid algorithm name.
    tink_kid: If set, the header must contain a 'kid' entry equal to it.
    custom_kid: If set, a 'kid' entry must equal it, but only when the header
      contains one. Must not be set together with tink_kid.

  Raises:
    _jwt_error.JwtInvalidError: if algorithm is invalid, if the 'alg' entry is
      not a string or does not match algorithm (ignoring case), if the header
      contains 'crit', if both tink_kid and custom_kid are set, or if a 'kid'
      check fails.
  """
  _validate_algorithm(algorithm)
  hdr_algorithm = header.get('alg', '')
  # header is typed Any, so 'alg' is not guaranteed to be a str. Reject other
  # types here so that .upper() cannot fail with an AttributeError.
  if not isinstance(hdr_algorithm, str) or hdr_algorithm.upper() != algorithm:
    raise _jwt_error.JwtInvalidError('Invalid algorithm; expected %s, got %s' %
                                     (algorithm, hdr_algorithm))
  if 'crit' in header:
    raise _jwt_error.JwtInvalidError(
        'all tokens with crit headers are rejected')
  if tink_kid is not None and custom_kid is not None:
    raise _jwt_error.JwtInvalidError('custom_kid can only be set for RAW keys')
  if tink_kid is not None:
    if 'kid' not in header:
      # for output prefix type TINK, the kid header is required
      raise _jwt_error.JwtInvalidError('missing kid in header')
    _validate_kid_header(header, tink_kid)
  # With a custom kid, the 'kid' header is optional and only checked if present.
  if custom_kid is not None and 'kid' in header:
    _validate_kid_header(header, custom_kid)
192
193
def get_type_header(header: Any) -> Optional[str]:
  """Returns the header's 'typ' entry, or None if there is none."""
  type_header = header.get('typ')
  return type_header
196
197
def create_unsigned_compact(algorithm: str, kid: Optional[str],
                            raw_jwt: _raw_jwt.RawJwt) -> bytes:
  """Builds the unsigned compact form 'header.payload' of a raw JWT.

  Args:
    algorithm: The algorithm name to put into the header.
    kid: The key ID to put into the header, or None.
    raw_jwt: The token whose type header and JSON payload are serialized.

  Returns:
    The encoded header and the encoded payload, joined by '.'.
  """
  # type_header() is only queried when the token reports having one.
  type_header = raw_jwt.type_header() if raw_jwt.has_type_header() else None
  encoded_header = create_header(algorithm, type_header, kid)
  encoded_payload = encode_payload(raw_jwt.json_payload())
  return encoded_header + b'.' + encoded_payload
205
206
def create_signed_compact(unsigned_compact: bytes, signature: bytes) -> str:
  """Appends the encoded signature to the unsigned compact.

  Args:
    unsigned_compact: The 'header.payload' part of the token.
    signature: The raw signature or MAC bytes.

  Returns:
    The complete token 'header.payload.signature' as a str.
  """
  encoded_signature = encode_signature(signature)
  signed_compact = unsigned_compact + b'.' + encoded_signature
  return signed_compact.decode('utf8')
209