# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The raw JSON Web Token (JWT)."""

import copy
import datetime
import json

from typing import cast, Mapping, Set, List, Dict, Optional, Union, Any

from tink import core
from tink.jwt import _json_util
from tink.jwt import _jwt_error

_REGISTERED_NAMES = frozenset({'iss', 'sub', 'jti', 'aud', 'exp', 'nbf', 'iat'})

_MAX_TIMESTAMP_VALUE = 253402300799  # 31 Dec 9999, 23:59:59 GMT

Claim = Union[None, bool, int, float, str, List[Any], Dict[str, Any]]


def _from_datetime(t: datetime.datetime) -> int:
  """Converts a timezone-aware datetime into an integer Unix timestamp.

  Fractional seconds are dropped by the int() conversion.

  Raises:
    JwtInvalidError: if t is naive, i.e. it has no tzinfo or its tzinfo's
      utcoffset() returns None. The timestamp of a naive datetime would
      depend on the local timezone of the machine.
  """
  # A datetime is only "aware" if tzinfo is set AND utcoffset() returns a
  # value; checking tzinfo alone misses the second case.
  if t.tzinfo is None or t.utcoffset() is None:
    raise _jwt_error.JwtInvalidError('datetime must have tzinfo')
  return int(t.timestamp())


def _to_datetime(timestamp: float) -> datetime.datetime:
  """Converts a Unix timestamp into a datetime in UTC."""
  return datetime.datetime.fromtimestamp(timestamp, datetime.timezone.utc)


def _validate_custom_claim_name(name: str) -> None:
  """Raises JwtInvalidError if name is one of the registered claim names."""
  if name in _REGISTERED_NAMES:
    raise _jwt_error.JwtInvalidError(
        'registered name %s cannot be custom claim name' % name)


def _copy_custom_claim_value(name: str, value: Claim) -> Claim:
  """Returns a copy of a custom claim value that is safe to store.

  None, bool, int, float and str values are immutable and returned as-is.
  Lists and dicts are round-tripped through JSON. This deep-copies them, so
  later mutation by the caller cannot change the token. It also normalizes
  them to plain JSON types; for example, nested tuples become lists.

  Args:
    name: the claim name, used only in error messages.
    value: the claim value supplied by the caller.

  Returns:
    The value to store in the payload.

  Raises:
    JwtInvalidError: if the value has an unsupported type, or a list/dict
      contains something that cannot be encoded as JSON.
  """
  if value is None or isinstance(value, (bool, int, float, str)):
    return value
  if isinstance(value, (list, dict)):
    try:
      return json.loads(json.dumps(value))
    except (TypeError, ValueError) as e:
      # json.dumps raises TypeError for unserializable objects and
      # ValueError for circular references.
      raise _jwt_error.JwtInvalidError(
          'claim %s contains a value that cannot be encoded as JSON' %
          name) from e
  raise _jwt_error.JwtInvalidError('claim %s has unknown type' % name)


class RawJwt:
  """An unencoded and unsigned JSON Web Token (JWT).

  It contains all payload claims and a subset of the headers. It does not
  contain any headers that depend on the key, such as "alg" or "kid", because
  these headers are chosen when the token is signed and encoded, and should not
  be chosen by the user. This ensures that the key can be changed without any
  changes to the user code.
  """

  def __new__(cls):
    # Instances are created only through object.__new__ in create() and
    # _from_json(); calling RawJwt(...) directly is rejected.
    raise core.TinkError('RawJwt cannot be instantiated directly.')

  def __init__(self, type_header: Optional[str],
               payload: Dict[str, Any]) -> None:
    """Stores the header and payload and validates the registered claims.

    Raises:
      JwtInvalidError: if payload is not a dict or a registered claim has an
        invalid type or value.
    """
    # No need to copy payload, because only create and _from_json call this
    # method, and both pass a dict that nobody else holds a reference to.
    if not isinstance(payload, dict):
      raise _jwt_error.JwtInvalidError('payload must be a dict')
    self._type_header = type_header
    self._payload = payload
    self._validate_string_claim('iss')
    self._validate_string_claim('sub')
    self._validate_string_claim('jti')
    self._validate_timestamp_claim('exp')
    self._validate_timestamp_claim('nbf')
    self._validate_timestamp_claim('iat')
    self._validate_audience_claim()

  def _validate_string_claim(self, name: str):
    """Checks that claim `name`, if present, is a str."""
    if name in self._payload:
      if not isinstance(self._payload[name], str):
        raise _jwt_error.JwtInvalidError('claim %s must be a String' % name)

  def _validate_timestamp_claim(self, name: str):
    """Checks that claim `name`, if present, is an in-range number.

    Valid values lie in [0, _MAX_TIMESTAMP_VALUE].
    """
    if name in self._payload:
      timestamp = self._payload[name]
      # bool is a subclass of int, but a JSON true/false is not a number.
      if (isinstance(timestamp, bool) or
          not isinstance(timestamp, (int, float))):
        raise _jwt_error.JwtInvalidError('claim %s must be a Number' % name)
      # A chained comparison also rejects NaN, for which every comparison
      # evaluates to False.
      if not 0 <= timestamp <= _MAX_TIMESTAMP_VALUE:
        raise _jwt_error.JwtInvalidError(
            'timestamp of claim %s is out of range' % name)

  def _validate_audience_claim(self):
    """The 'aud' claim must either be a string or a list of strings."""
    if 'aud' in self._payload:
      audiences = self._payload['aud']
      if isinstance(audiences, str):
        return
      if not isinstance(audiences, list):
        raise _jwt_error.JwtInvalidError(
            'audiences must be a string or a list of strings')
      if not audiences:
        raise _jwt_error.JwtInvalidError('audiences cannot be an empty list')
      if not all(isinstance(value, str) for value in audiences):
        raise _jwt_error.JwtInvalidError('audiences must only contain strings')

  # TODO(juerg): Consider adding a raw_ prefix to all access methods
  def has_type_header(self) -> bool:
    return self._type_header is not None

  def type_header(self) -> str:
    """Returns the type header; raises KeyError if it is not set."""
    if not self.has_type_header():
      raise KeyError('type header is not set')
    return cast(str, self._type_header)

  def has_issuer(self) -> bool:
    return 'iss' in self._payload

  def issuer(self) -> str:
    """Returns the 'iss' claim; raises KeyError if it is not set."""
    return cast(str, self._payload['iss'])

  def has_subject(self) -> bool:
    return 'sub' in self._payload

  def subject(self) -> str:
    """Returns the 'sub' claim; raises KeyError if it is not set."""
    return cast(str, self._payload['sub'])

  def has_audiences(self) -> bool:
    return 'aud' in self._payload

  def audiences(self) -> List[str]:
    """Returns the 'aud' claim as a new list.

    A single-string audience is wrapped in a one-element list. Raises KeyError
    if the claim is not set.
    """
    aud = self._payload['aud']
    if isinstance(aud, str):
      return [aud]
    return list(aud)

  def has_jwt_id(self) -> bool:
    return 'jti' in self._payload

  def jwt_id(self) -> str:
    """Returns the 'jti' claim; raises KeyError if it is not set."""
    return cast(str, self._payload['jti'])

  def has_expiration(self) -> bool:
    return 'exp' in self._payload

  def expiration(self) -> datetime.datetime:
    """Returns the 'exp' claim as a UTC datetime; KeyError if not set."""
    return _to_datetime(self._payload['exp'])

  def has_not_before(self) -> bool:
    return 'nbf' in self._payload

  def not_before(self) -> datetime.datetime:
    """Returns the 'nbf' claim as a UTC datetime; KeyError if not set."""
    return _to_datetime(self._payload['nbf'])

  def has_issued_at(self) -> bool:
    return 'iat' in self._payload

  def issued_at(self) -> datetime.datetime:
    """Returns the 'iat' claim as a UTC datetime; KeyError if not set."""
    return _to_datetime(self._payload['iat'])

  def custom_claim_names(self) -> Set[str]:
    """Returns the names of all claims that are not registered claims."""
    return {n for n in self._payload if n not in _REGISTERED_NAMES}

  def custom_claim(self, name: str) -> Claim:
    """Returns the value of a custom claim.

    Lists and dicts are deep-copied so the caller cannot mutate the token.

    Raises:
      JwtInvalidError: if name is a registered claim name.
      KeyError: if the claim is not set.
    """
    _validate_custom_claim_name(name)
    value = self._payload[name]
    if isinstance(value, (list, dict)):
      return copy.deepcopy(value)
    return value

  def json_payload(self) -> str:
    """Returns the payload encoded as JSON string."""
    return _json_util.json_dumps(self._payload)

  @classmethod
  def create(cls,
             *,
             type_header: Optional[str] = None,
             issuer: Optional[str] = None,
             subject: Optional[str] = None,
             audience: Optional[str] = None,
             audiences: Optional[List[str]] = None,
             jwt_id: Optional[str] = None,
             expiration: Optional[datetime.datetime] = None,
             without_expiration: Optional[bool] = None,
             not_before: Optional[datetime.datetime] = None,
             issued_at: Optional[datetime.datetime] = None,
             custom_claims: Optional[Mapping[str, Claim]] = None) -> 'RawJwt':
    """Create a new RawJwt instance.

    Every claim argument that is not None is added to the payload. Empty
    strings are kept, because they are valid claim values.

    Args:
      type_header: optional "typ" header.
      issuer: optional "iss" claim.
      subject: optional "sub" claim.
      audience: optional single "aud" value; exclusive with audiences.
      audiences: optional list of "aud" values; it is copied.
      jwt_id: optional "jti" claim.
      expiration: "exp" as a timezone-aware datetime; required unless
        without_expiration is set.
      without_expiration: must be set to explicitly create a token without
        "exp".
      not_before: optional "nbf" as a timezone-aware datetime.
      issued_at: optional "iat" as a timezone-aware datetime.
      custom_claims: optional mapping of non-registered claim names to values.
        List and dict values are copied.

    Returns:
      A new RawJwt.

    Raises:
      ValueError: if exactly one of expiration / without_expiration is not
        set.
      JwtInvalidError: if audience and audiences are both set, a datetime is
        naive, a custom claim name or value is invalid, or a resulting claim
        fails validation.
    """
    if expiration is None and not without_expiration:
      raise ValueError('either expiration or without_expiration must be set')
    if expiration is not None and without_expiration:
      raise ValueError(
          'expiration and without_expiration cannot be set at the same time')
    if audience is not None and audiences is not None:
      raise _jwt_error.JwtInvalidError(
          'audience and audiences cannot be set at the same time')
    payload: Dict[str, Any] = {}
    if issuer is not None:
      payload['iss'] = issuer
    if subject is not None:
      payload['sub'] = subject
    if jwt_id is not None:
      payload['jti'] = jwt_id
    if audience is not None:
      payload['aud'] = audience
    if audiences is not None:
      # Copy so later changes to the caller's list don't affect the token.
      payload['aud'] = copy.copy(audiences)
    if expiration is not None:
      payload['exp'] = _from_datetime(expiration)
    if not_before is not None:
      payload['nbf'] = _from_datetime(not_before)
    if issued_at is not None:
      payload['iat'] = _from_datetime(issued_at)
    if custom_claims:
      for name, value in custom_claims.items():
        if not isinstance(name, str):
          raise _jwt_error.JwtInvalidError('claim name must be Text')
        _validate_custom_claim_name(name)
        payload[name] = _copy_custom_claim_value(name, value)
    raw_jwt = object.__new__(cls)
    raw_jwt.__init__(type_header, payload)
    return raw_jwt

  @classmethod
  def _from_json(cls, type_header: Optional[str], payload: str) -> 'RawJwt':
    """Creates a RawJwt from payload encoded as JSON string."""
    raw_jwt = object.__new__(cls)
    raw_jwt.__init__(type_header, _json_util.json_loads(payload))
    return raw_jwt


def new_raw_jwt(*,
                type_header: Optional[str] = None,
                issuer: Optional[str] = None,
                subject: Optional[str] = None,
                audience: Optional[str] = None,
                audiences: Optional[List[str]] = None,
                jwt_id: Optional[str] = None,
                expiration: Optional[datetime.datetime] = None,
                without_expiration: bool = False,
                not_before: Optional[datetime.datetime] = None,
                issued_at: Optional[datetime.datetime] = None,
                custom_claims: Optional[Mapping[str, Claim]] = None) -> RawJwt:
  """Creates a new RawJwt.

  This is a thin wrapper that forwards every argument to RawJwt.create; see
  that method for the meaning of each argument and the errors it raises.
  """
  return RawJwt.create(
      type_header=type_header,
      issuer=issuer,
      subject=subject,
      audience=audience,
      audiences=audiences,
      jwt_id=jwt_id,
      expiration=expiration,
      without_expiration=without_expiration,
      not_before=not_before,
      issued_at=issued_at,
      custom_claims=custom_claims)


def raw_jwt_from_json(type_header: Optional[str], payload: str) -> RawJwt:
  """Internal function used to verify JWT token."""
  return RawJwt._from_json(type_header, payload)  # pylint: disable=protected-access