# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions that help to serialize and deserialize from/to the JWT format."""

import base64
import binascii
import struct
from typing import Any, Optional, Tuple

from tink.proto import tink_pb2
from tink.jwt import _json_util
from tink.jwt import _jwt_error
from tink.jwt import _raw_jwt

# The only "alg" header values this module accepts.
_VALID_ALGORITHMS = frozenset({
    'HS256', 'HS384', 'HS512', 'ES256', 'ES384', 'ES512', 'RS256', 'RS384',
    'RS512', 'PS256', 'PS384', 'PS512'
})


def base64_encode(data: bytes) -> bytes:
  """Does a URL-safe base64 encoding without padding."""
  return base64.urlsafe_b64encode(data).rstrip(b'=')


def _is_valid_urlsafe_base64_char(c: int) -> bool:
  """Returns True iff byte value `c` is in the URL-safe base64 alphabet.

  The alphabet is A-Z, a-z, 0-9, '-' and '_'; the padding char '=' is NOT
  accepted.
  """
  return (ord('a') <= c <= ord('z') or ord('A') <= c <= ord('Z') or
          ord('0') <= c <= ord('9') or c == ord('-') or c == ord('_'))


def base64_decode(encoded_data: bytes) -> bytes:
  """Does a URL-safe base64 decoding without padding.

  Raises:
    _jwt_error.JwtInvalidError: if `encoded_data` contains a byte outside the
      URL-safe alphabet (including '=') or has an impossible length.
  """
  # base64.urlsafe_b64decode ignores all non-base64 chars. We don't want that.
  for c in encoded_data:
    if not _is_valid_urlsafe_base64_char(c):
      raise _jwt_error.JwtInvalidError('invalid base64 encoding')
  # base64.urlsafe_b64decode requires padding, but does not mind too much
  # padding. So we simply add the maximum amount of padding needed.
  padded_encoded_data = encoded_data + b'==='
  try:
    return base64.urlsafe_b64decode(padded_encoded_data)
  except binascii.Error:
    # Throws when the length of encoded_data is (4*i + 1) for some i
    raise _jwt_error.JwtInvalidError('invalid base64 encoding')


def _validate_algorithm(algorithm: str) -> None:
  """Raises JwtInvalidError unless `algorithm` is in _VALID_ALGORITHMS."""
  if algorithm not in _VALID_ALGORITHMS:
    raise _jwt_error.JwtInvalidError('Invalid algorithm %s' % algorithm)


def encode_header(json_header: str) -> bytes:
  """UTF-8 encodes the JSON header and base64url-encodes it without padding."""
  try:
    return base64_encode(json_header.encode('utf8'))
  except UnicodeEncodeError:
    raise _jwt_error.JwtInvalidError('invalid token')


def decode_header(encoded_header: bytes) -> str:
  """Reverses encode_header; raises JwtInvalidError on bad base64 or UTF-8."""
  try:
    return base64_decode(encoded_header).decode('utf8')
  except UnicodeDecodeError:
    raise _jwt_error.JwtInvalidError('invalid token')


def encode_payload(json_payload: str) -> bytes:
  """Encodes the payload into compact form."""
  try:
    return base64_encode(json_payload.encode('utf8'))
  except UnicodeEncodeError:
    raise _jwt_error.JwtInvalidError('invalid token')


def decode_payload(encoded_payload: bytes) -> str:
  """Decodes the payload from compact form."""
  try:
    return base64_decode(encoded_payload).decode('utf8')
  except UnicodeDecodeError:
    raise _jwt_error.JwtInvalidError('invalid token')


def encode_signature(signature: bytes) -> bytes:
  """Encodes the signature."""
  return base64_encode(signature)


def decode_signature(encoded_signature: bytes) -> bytes:
  """Decodes the signature."""
  return base64_decode(encoded_signature)


def create_header(algorithm: str, type_header: Optional[str],
                  kid: Optional[str]) -> bytes:
  """Builds the encoded JWT header.

  Args:
    algorithm: value for "alg"; must be in _VALID_ALGORITHMS.
    type_header: value for "typ"; omitted when falsy (None or '').
    kid: value for "kid"; omitted when falsy (None or '').

  Returns:
    The base64url-encoded (unpadded) JSON header.

  Raises:
    _jwt_error.JwtInvalidError: if `algorithm` is not supported.
  """
  _validate_algorithm(algorithm)
  header = {}
  if kid:
    header['kid'] = kid
  header['alg'] = algorithm
  if type_header:
    header['typ'] = type_header
  return encode_header(_json_util.json_dumps(header))


def get_kid(key_id: int, prefix: tink_pb2.OutputPrefixType) -> Optional[str]:
  """Returns the encoded key_id, or None.

  For RAW keys no kid is derived (None). For TINK keys the kid is the key id
  packed as a 4-byte big-endian unsigned int, base64url-encoded without
  padding.

  Raises:
    _jwt_error.JwtInvalidError: if key_id does not fit in an unsigned 32-bit
      int, or if the output prefix type is neither RAW nor TINK.
  """
  if prefix == tink_pb2.RAW:
    return None
  if prefix == tink_pb2.TINK:
    # '>L' packs an unsigned 32-bit int, so the valid range is [0, 2**32);
    # 2**32 itself would make struct.pack raise struct.error.
    if key_id < 0 or key_id >= 2**32:
      raise _jwt_error.JwtInvalidError('invalid key_id')
    return base64_encode(struct.pack('>L', key_id)).decode('utf8')
  raise _jwt_error.JwtInvalidError('unexpected output prefix type')


def split_signed_compact(signed_compact: str) -> Tuple[bytes, str, str, bytes]:
  """Splits a signed compact into its parts.

  Args:
    signed_compact: A signed compact JWT.

  Returns:
    A (unsigned_compact, json_header, json_payload, signature_or_mac) tuple.
  Raises:
    _jwt_error.JwtInvalidError if it fails.
  """
  if not isinstance(signed_compact, str):
    raise _jwt_error.JwtInvalidError('invalid token: not a str')
  try:
    encoded = signed_compact.encode('utf8')
  except UnicodeEncodeError:
    raise _jwt_error.JwtInvalidError('invalid token')
  try:
    # The signature is everything after the LAST '.'.
    unsigned_compact, encoded_signature = encoded.rsplit(b'.', 1)
  except ValueError:
    raise _jwt_error.JwtInvalidError('invalid token')
  signature_or_mac = decode_signature(encoded_signature)
  try:
    # The remainder must contain exactly one '.' (header.payload).
    encoded_header, encoded_payload = unsigned_compact.split(b'.')
  except ValueError:
    raise _jwt_error.JwtInvalidError('invalid token')

  json_header = decode_header(encoded_header)
  json_payload = decode_payload(encoded_payload)
  return (unsigned_compact, json_header, json_payload, signature_or_mac)


def _validate_kid_header(header: Any, kid: str) -> None:
  """Raises JwtInvalidError unless header['kid'] equals `kid`."""
  if header['kid'] != kid:
    raise _jwt_error.JwtInvalidError('invalid kid header')


def validate_header(header: Any,
                    algorithm: str,
                    tink_kid: Optional[str] = None,
                    custom_kid: Optional[str] = None) -> None:
  """Validates the values of an already-parsed JWT header.

  Args:
    header: the parsed header; accessed with `get`, `in` and `[]`, presumably
      the dict obtained from the JSON header (confirm with callers).
    algorithm: the algorithm expected by the key; must be in
      _VALID_ALGORITHMS and must equal the header's "alg" exactly.
    tink_kid: if set (TINK output prefix), the header must contain a "kid"
      equal to it.
    custom_kid: if set (RAW keys only), a "kid" in the header, when present,
      must equal it.

  Raises:
    _jwt_error.JwtInvalidError: if any of the checks fails, if the header
      has a "crit" entry, or if both tink_kid and custom_kid are set.
  """
  _validate_algorithm(algorithm)
  hdr_algorithm = header.get('alg', '')
  # "alg" is a case-sensitive string (RFC 7515, section 4.1.1), so compare
  # exactly. An exact comparison also rejects non-string "alg" values cleanly
  # instead of failing with AttributeError on `.upper()`.
  if hdr_algorithm != algorithm:
    raise _jwt_error.JwtInvalidError('Invalid algorithm; expected %s, got %s' %
                                     (algorithm, hdr_algorithm))
  if 'crit' in header:
    raise _jwt_error.JwtInvalidError(
        'all tokens with crit headers are rejected')
  if tink_kid is not None and custom_kid is not None:
    raise _jwt_error.JwtInvalidError('custom_kid can only be set for RAW keys')
  if tink_kid is not None:
    if 'kid' not in header:
      # for output prefix type TINK, the kid header is required
      raise _jwt_error.JwtInvalidError('missing kid in header')
    _validate_kid_header(header, tink_kid)
  if custom_kid is not None and 'kid' in header:
    _validate_kid_header(header, custom_kid)


def get_type_header(header: Any) -> Optional[str]:
  """Returns the header's "typ" value, or None if it is absent."""
  return header.get('typ', None)


def create_unsigned_compact(algorithm: str, kid: Optional[str],
                            raw_jwt: _raw_jwt.RawJwt) -> bytes:
  """Returns b'<encoded header>.<encoded payload>' for `raw_jwt`."""
  if raw_jwt.has_type_header():
    header = create_header(algorithm, raw_jwt.type_header(), kid)
  else:
    header = create_header(algorithm, None, kid)
  return header + b'.' + encode_payload(raw_jwt.json_payload())


def create_signed_compact(unsigned_compact: bytes, signature: bytes) -> str:
  """Appends '.' and the base64url-encoded signature; returns a str."""
  return (unsigned_compact + b'.' + encode_signature(signature)).decode('utf8')