xref: /aosp_15_r20/external/tink/python/tink/jwt/_jwt_mac_wrapper.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Python primitive set wrapper for the JwtMac primitive."""
15
16from typing import Type
17
18from tink.proto import tink_pb2
19from tink import core
20from tink.jwt import _jwt_error
21from tink.jwt import _jwt_format
22from tink.jwt import _jwt_mac
23from tink.jwt import _jwt_validator
24from tink.jwt import _raw_jwt
25from tink.jwt import _verified_jwt
26
27
class _WrappedJwtMac(_jwt_mac.JwtMac):
  """JwtMac implementation backed by the entries of a primitive set.

  Tokens are produced with the primary entry of the set. Verification tries
  the entries of the set one after another and returns the first success.
  """

  def __init__(self, pset: core.PrimitiveSet):
    # The primitive set is stored as-is; this constructor performs no checks.
    self._primitive_set = pset

  def compute_mac_and_encode(self, raw_jwt: _raw_jwt.RawJwt) -> str:
    """Computes a MAC over the token with the primary key and encodes it.

    Args:
      raw_jwt: The RawJwt token to be MACed and encoded.

    Returns:
      The value returned by the primary primitive's
      compute_mac_and_encode_with_kid; per the declared return type, the
      encoded token as a str.
    Raises:
      tink.TinkError if the operation fails.
    """
    primary_entry = self._primitive_set.primary()
    # The kid is derived from the primary entry's key id and prefix type.
    primary_kid = _jwt_format.get_kid(primary_entry.key_id,
                                      primary_entry.output_prefix_type)
    primary_mac = primary_entry.primitive
    return primary_mac.compute_mac_and_encode_with_kid(raw_jwt, primary_kid)

  def verify_mac_and_decode(
      self, compact: str,
      validator: _jwt_validator.JwtValidator) -> _verified_jwt.VerifiedJwt:
    """Verifies, validates and decodes a MACed compact JWT token.

    Every entry of the primitive set is tried in iteration order; the result
    of the first entry that does not raise a tink.TinkError is returned.

    Args:
      compact: A MACed token encoded in the JWS compact serialization format.
      validator: A JwtValidator that validates the token.

    Returns:
      A VerifiedJwt.
    Raises:
      tink.TinkError if no entry accepts the token. When at least one entry
      failed with a JwtInvalidError, the most recent such error is raised;
      otherwise a generic 'invalid MAC' error is raised.
    """
    last_invalid_error = None
    # Lazily flatten the groups of entries returned by the primitive set.
    candidates = (
        candidate
        for group in self._primitive_set.all()
        for candidate in group
    )
    for candidate in candidates:
      try:
        candidate_kid = _jwt_format.get_kid(candidate.key_id,
                                            candidate.output_prefix_type)
        return candidate.primitive.verify_mac_and_decode_with_kid(
            compact, validator, candidate_kid)
      except core.TinkError as err:
        # Any TinkError just means "try the next entry". A JwtInvalidError is
        # remembered so it can be reported instead of the generic error below
        # (presumably because it is more informative -- confirm with the
        # JwtMacInternal implementations).
        if isinstance(err, _jwt_error.JwtInvalidError):
          last_invalid_error = err
    if not last_invalid_error:
      raise core.TinkError('invalid MAC')
    raise last_invalid_error
77
78
def _validate_primitive_set(pset: core.PrimitiveSet):
  """Checks that every entry of the set uses a supported output prefix type.

  Args:
    pset: The primitive set whose entries are inspected.

  Raises:
    tink.TinkError if an entry has an output prefix type other than RAW or
    TINK.
  """
  supported_prefix_types = (tink_pb2.RAW, tink_pb2.TINK)
  for group in pset.all():
    for item in group:
      if item.output_prefix_type not in supported_prefix_types:
        raise core.TinkError('unsupported OutputPrefixType')
85
86
class _Wrapper(core.PrimitiveWrapper[_jwt_mac.JwtMacInternal, _jwt_mac.JwtMac]):
  """Primitive wrapper producing a JwtMac from a set of JwtMacInternal."""

  def wrap(self, pset: core.PrimitiveSet) -> _jwt_mac.JwtMac:
    """Validates the primitive set and wraps it into a single JwtMac.

    Args:
      pset: The primitive set to wrap.

    Returns:
      A JwtMac that delegates to the entries of pset.
    Raises:
      tink.TinkError if an entry has an unsupported output prefix type.
    """
    # Validation happens first so an invalid set is never wrapped.
    _validate_primitive_set(pset)
    wrapped_mac = _WrappedJwtMac(pset)
    return wrapped_mac

  def primitive_class(self) -> Type[_jwt_mac.JwtMac]:
    """Returns the class of the primitive produced by wrap()."""
    return _jwt_mac.JwtMac

  def input_primitive_class(self) -> Type[_jwt_mac.JwtMacInternal]:
    """Returns the class of the primitives expected inside the set."""
    return _jwt_mac.JwtMacInternal
99
100
def register():
  """Registers the JwtMac primitive wrapper with the Tink registry."""
  jwt_mac_wrapper = _Wrapper()
  core.Registry.register_primitive_wrapper(jwt_mac_wrapper)
103