xref: /aosp_15_r20/external/tink/python/tink/jwt/_raw_jwt.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13"""The raw JSON Web Token (JWT)."""
14
15import copy
16import datetime
17import json
18
19from typing import cast, Mapping, Set, List, Dict, Optional, Union, Any
20
21from tink import core
22from tink.jwt import _json_util
23from tink.jwt import _jwt_error
24
# Claim names registered in RFC 7519, section 4.1. They have dedicated
# accessors and cannot be used as custom claim names.
_REGISTERED_NAMES = frozenset(
    ('iss', 'sub', 'jti', 'aud', 'exp', 'nbf', 'iat'))

# Largest timestamp accepted for 'exp', 'nbf' and 'iat': 9999-12-31 23:59:59
# UTC (253402300799), the last whole second datetime.datetime can represent.
_MAX_TIMESTAMP_VALUE = int(
    datetime.datetime(9999, 12, 31, 23, 59, 59,
                      tzinfo=datetime.timezone.utc).timestamp())

# Type of a custom claim value: null, boolean, number, string, array or object.
Claim = Union[None, bool, int, float, str, List[Any], Dict[str, Any]]
30
31
def _from_datetime(t: datetime.datetime) -> int:
  """Converts an aware datetime to whole seconds since the Unix epoch.

  Fractional seconds are dropped (int() truncates toward zero).

  Args:
    t: the datetime to convert; it must be timezone-aware.

  Returns:
    The POSIX timestamp of t as an int.

  Raises:
    JwtInvalidError: if t is naive, i.e. it has no tzinfo or its tzinfo
      reports no UTC offset for t.
  """
  # A datetime is only "aware" when tzinfo is set AND utcoffset() returns a
  # value. Otherwise timestamp() silently interprets it in the host's local
  # time zone, which would make the resulting claim machine-dependent.
  if not t.tzinfo or t.utcoffset() is None:
    raise _jwt_error.JwtInvalidError('datetime must have tzinfo')
  return int(t.timestamp())
36
37
def _to_datetime(timestamp: float) -> datetime.datetime:
  """Interprets seconds since the Unix epoch as an aware UTC datetime."""
  utc = datetime.timezone.utc
  return datetime.datetime.fromtimestamp(timestamp, tz=utc)
40
41
def _validate_custom_claim_name(name: str) -> None:
  """Ensures a custom claim does not reuse a registered claim name.

  Raises:
    JwtInvalidError: if name is one of the registered claim names.
  """
  if name not in _REGISTERED_NAMES:
    return
  raise _jwt_error.JwtInvalidError(
      'registered name %s cannot be custom claim name' % name)
46
47
48class RawJwt:
49  """An unencoded and unsigned JSON Web Token (JWT).
50
51  It contains all payload claims and a subset of the headers. It does not
52  contain any headers that depend on the key, such as "alg" or "kid", because
53  these headers are chosen when the token is signed and encoded, and should not
54  be chosen by the user. This ensures that the key can be changed without any
55  changes to the user code.
56  """
57
58  def __new__(cls):
59    raise core.TinkError('RawJwt cannot be instantiated directly.')
60
61  def __init__(self, type_header: Optional[str], payload: Dict[str,
62                                                               Any]) -> None:
63    # No need to copy payload, because only create and from_json_payload
64    # call this method.
65    if not isinstance(payload, Dict):
66      raise _jwt_error.JwtInvalidError('payload must be a dict')
67    self._type_header = type_header
68    self._payload = payload
69    self._validate_string_claim('iss')
70    self._validate_string_claim('sub')
71    self._validate_string_claim('jti')
72    self._validate_timestamp_claim('exp')
73    self._validate_timestamp_claim('nbf')
74    self._validate_timestamp_claim('iat')
75    self._validate_audience_claim()
76
77  def _validate_string_claim(self, name: str):
78    if name in self._payload:
79      if not isinstance(self._payload[name], str):
80        raise _jwt_error.JwtInvalidError('claim %s must be a String' % name)
81
82  def _validate_timestamp_claim(self, name: str):
83    if name in self._payload:
84      timestamp = self._payload[name]
85      if not isinstance(timestamp, (int, float)):
86        raise _jwt_error.JwtInvalidError('claim %s must be a Number' % name)
87      if timestamp > _MAX_TIMESTAMP_VALUE or timestamp < 0:
88        raise _jwt_error.JwtInvalidError(
89            'timestamp of claim %s is out of range' % name)
90
91  def _validate_audience_claim(self):
92    """The 'aud' claim must either be a string or a list of strings."""
93    if 'aud' in self._payload:
94      audiences = self._payload['aud']
95      if isinstance(audiences, str):
96        return
97      if not isinstance(audiences, list) or not audiences:
98        raise _jwt_error.JwtInvalidError('audiences cannot be an empty list')
99      if not all(isinstance(value, str) for value in audiences):
100        raise _jwt_error.JwtInvalidError('audiences must only contain strings')
101
102  # TODO(juerg): Consider adding a raw_ prefix to all access methods
103  def has_type_header(self) -> bool:
104    return self._type_header is not None
105
106  def type_header(self) -> str:
107    if not self.has_type_header():
108      raise KeyError('type header is not set')
109    return self._type_header
110
111  def has_issuer(self) -> bool:
112    return 'iss' in self._payload
113
114  def issuer(self) -> str:
115    return cast(str, self._payload['iss'])
116
117  def has_subject(self) -> bool:
118    return 'sub' in self._payload
119
120  def subject(self) -> str:
121    return cast(str, self._payload['sub'])
122
123  def has_audiences(self) -> bool:
124    return 'aud' in self._payload
125
126  def audiences(self) -> List[str]:
127    aud = self._payload['aud']
128    if isinstance(aud, str):
129      return [aud]
130    return list(aud)
131
132  def has_jwt_id(self) -> bool:
133    return 'jti' in self._payload
134
135  def jwt_id(self) -> str:
136    return cast(str, self._payload['jti'])
137
138  def has_expiration(self) -> bool:
139    return 'exp' in self._payload
140
141  def expiration(self) -> datetime.datetime:
142    return _to_datetime(self._payload['exp'])
143
144  def has_not_before(self) -> bool:
145    return 'nbf' in self._payload
146
147  def not_before(self) -> datetime.datetime:
148    return _to_datetime(self._payload['nbf'])
149
150  def has_issued_at(self) -> bool:
151    return 'iat' in self._payload
152
153  def issued_at(self) -> datetime.datetime:
154    return _to_datetime(self._payload['iat'])
155
156  def custom_claim_names(self) -> Set[str]:
157    return {n for n in self._payload.keys() if n not in _REGISTERED_NAMES}
158
159  def custom_claim(self, name: str) -> Claim:
160    _validate_custom_claim_name(name)
161    value = self._payload[name]
162    if isinstance(value, (list, dict)):
163      return copy.deepcopy(value)
164    else:
165      return value
166
167  def json_payload(self) -> str:
168    """Returns the payload encoded as JSON string."""
169    return _json_util.json_dumps(self._payload)
170
171  @classmethod
172  def create(cls,
173             *,
174             type_header: Optional[str] = None,
175             issuer: Optional[str] = None,
176             subject: Optional[str] = None,
177             audience: Optional[str] = None,
178             audiences: Optional[List[str]] = None,
179             jwt_id: Optional[str] = None,
180             expiration: Optional[datetime.datetime] = None,
181             without_expiration: Optional[bool] = None,
182             not_before: Optional[datetime.datetime] = None,
183             issued_at: Optional[datetime.datetime] = None,
184             custom_claims: Optional[Mapping[str, Claim]] = None) -> 'RawJwt':
185    """Create a new RawJwt instance."""
186    if not expiration and not without_expiration:
187      raise ValueError('either expiration or without_expiration must be set')
188    if expiration and without_expiration:
189      raise ValueError(
190          'expiration and without_expiration cannot be set at the same time')
191    if audience is not None and audiences is not None:
192      raise _jwt_error.JwtInvalidError(
193          'audience and audiences cannot be set at the same time')
194    payload = {}
195    if issuer:
196      payload['iss'] = issuer
197    if subject:
198      payload['sub'] = subject
199    if jwt_id is not None:
200      payload['jti'] = jwt_id
201    if audience is not None:
202      payload['aud'] = audience
203    if audiences is not None:
204      payload['aud'] = copy.copy(audiences)
205    if expiration:
206      payload['exp'] = _from_datetime(expiration)
207    if not_before:
208      payload['nbf'] = _from_datetime(not_before)
209    if issued_at:
210      payload['iat'] = _from_datetime(issued_at)
211    if custom_claims:
212      for name, value in custom_claims.items():
213        _validate_custom_claim_name(name)
214        if not isinstance(name, str):
215          raise _jwt_error.JwtInvalidError('claim name must be Text')
216        if (value is None or isinstance(value, (bool, int, float, str))):
217          payload[name] = value
218        elif isinstance(value, list):
219          payload[name] = json.loads(json.dumps(value))
220        elif isinstance(value, dict):
221          payload[name] = json.loads(json.dumps(value))
222        else:
223          raise _jwt_error.JwtInvalidError('claim %s has unknown type' % name)
224    raw_jwt = object.__new__(cls)
225    raw_jwt.__init__(type_header, payload)
226    return raw_jwt
227
228  @classmethod
229  def _from_json(cls, type_header: Optional[str], payload: str) -> 'RawJwt':
230    """Creates a RawJwt from payload encoded as JSON string."""
231    raw_jwt = object.__new__(cls)
232    raw_jwt.__init__(type_header, _json_util.json_loads(payload))
233    return raw_jwt
234
235
def new_raw_jwt(*,
                type_header: Optional[str] = None,
                issuer: Optional[str] = None,
                subject: Optional[str] = None,
                audience: Optional[str] = None,
                audiences: Optional[List[str]] = None,
                jwt_id: Optional[str] = None,
                expiration: Optional[datetime.datetime] = None,
                without_expiration: bool = False,
                not_before: Optional[datetime.datetime] = None,
                issued_at: Optional[datetime.datetime] = None,
                custom_claims: Optional[Mapping[str, Claim]] = None) -> RawJwt:
  """Builds a new RawJwt from a type header and a set of claims.

  Convenience wrapper that forwards every argument unchanged to
  RawJwt.create(); see that method for the validation it performs and the
  errors it raises.
  """
  token = RawJwt.create(
      # Header.
      type_header=type_header,
      # Registered identity claims.
      issuer=issuer,
      subject=subject,
      jwt_id=jwt_id,
      audience=audience,
      audiences=audiences,
      # Registered time claims.
      expiration=expiration,
      without_expiration=without_expiration,
      not_before=not_before,
      issued_at=issued_at,
      # Application-defined claims.
      custom_claims=custom_claims,
  )
  return token
261
262
def raw_jwt_from_json(type_header: Optional[str], payload: str) -> RawJwt:
  """Builds a RawJwt from a JSON-encoded payload.

  Internal helper used when verifying a JWT token.

  Args:
    type_header: the token's type header, or None if it has none.
    payload: the JWT claims as a JSON string.

  Returns:
    A RawJwt whose registered claims have been validated.
  """
  parse = RawJwt._from_json  # pylint: disable=protected-access
  return parse(type_header, payload)
266