xref: /aosp_15_r20/external/tink/python/tink/jwt/_jwt_signature_wrappers.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2021 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
"""Python primitive set wrappers for the JwtPublicKeySign and JwtPublicKeyVerify primitives."""
15
16from typing import Type
17
18from tink.proto import tink_pb2
19from tink import core
20from tink.jwt import _jwt_error
21from tink.jwt import _jwt_format
22from tink.jwt import _jwt_public_key_sign
23from tink.jwt import _jwt_public_key_verify
24from tink.jwt import _jwt_validator
25from tink.jwt import _raw_jwt
26from tink.jwt import _verified_jwt
27
28
class _WrappedJwtPublicKeySign(_jwt_public_key_sign.JwtPublicKeySign):
  """JwtPublicKeySign backed by the primary entry of a primitive set.

  Only the primary entry is ever used for signing; the other entries of the
  set are ignored by this class.
  """

  def __init__(self, pset: core.PrimitiveSet):
    self._primitive_set = pset

  def sign_and_encode(self, raw_jwt: _raw_jwt.RawJwt) -> str:
    """Signs raw_jwt with the primary entry and returns the encoded result.

    A kid value is computed by _jwt_format.get_kid from the primary entry's
    key ID and output prefix type, and is handed to the primary primitive's
    sign_and_encode_with_kid together with raw_jwt.
    """
    entry = self._primitive_set.primary()
    key_id_value = _jwt_format.get_kid(entry.key_id, entry.output_prefix_type)
    signer = entry.primitive
    return signer.sign_and_encode_with_kid(raw_jwt, key_id_value)
39
40
class _WrappedJwtPublicKeyVerify(_jwt_public_key_verify.JwtPublicKeyVerify):
  """JwtPublicKeyVerify that tries every entry of a primitive set in turn."""

  def __init__(self, pset: core.PrimitiveSet):
    self._primitive_set = pset

  def verify_and_decode(
      self, compact: str,
      validator: _jwt_validator.JwtValidator) -> _verified_jwt.VerifiedJwt:
    """Verifies and decodes compact with the entries of the primitive set.

    Entries are tried in the order produced by the primitive set's all().
    The result of the first entry whose verify_and_decode_with_kid call does
    not raise a TinkError is returned.

    Args:
      compact: the token to verify and decode.
      validator: passed unchanged to each entry's verify_and_decode_with_kid.

    Returns:
      The VerifiedJwt produced by the first entry that succeeds.

    Raises:
      _jwt_error.JwtInvalidError: if no entry succeeded and at least one
        entry raised JwtInvalidError (the last such error is re-raised).
      core.TinkError: with message 'invalid signature' if no entry succeeded
        and none of them raised JwtInvalidError.
    """
    interesting_error = None
    for entries in self._primitive_set.all():
      for entry in entries:
        try:
          kid = _jwt_format.get_kid(entry.key_id, entry.output_prefix_type)
          return entry.primitive.verify_and_decode_with_kid(
              compact, validator, kid)
        except core.TinkError as e:
          # Other TinkErrors are swallowed so the next entry can be tried. A
          # JwtInvalidError is remembered so it can be reported instead of
          # the generic error if no entry succeeds.
          if isinstance(e, _jwt_error.JwtInvalidError):
            interesting_error = e
    if interesting_error is not None:
      raise interesting_error
    raise core.TinkError('invalid signature')
64
65
def _validate_primitive_set(pset: core.PrimitiveSet):
  """Checks that every entry of pset uses a supported output prefix type.

  Only tink_pb2.RAW and tink_pb2.TINK are accepted.

  Raises:
    core.TinkError: if any entry uses a different output prefix type.
  """
  supported_prefix_types = (tink_pb2.RAW, tink_pb2.TINK)
  for key_entries in pset.all():
    for key_entry in key_entries:
      if key_entry.output_prefix_type not in supported_prefix_types:
        raise core.TinkError('unsupported OutputPrefixType')
71        raise core.TinkError('unsupported OutputPrefixType')
72
73
class _JwtPublicKeySignWrapper(
    core.PrimitiveWrapper[_jwt_public_key_sign.JwtPublicKeySignInternal,
                          _jwt_public_key_sign.JwtPublicKeySign]):
  """Wraps JwtPublicKeySignInternal primitive sets as a JwtPublicKeySign.

  The resulting primitive signs with the primary entry of the set.
  """

  def input_primitive_class(
      self) -> Type[_jwt_public_key_sign.JwtPublicKeySignInternal]:
    """Returns the class of the primitives this wrapper accepts."""
    return _jwt_public_key_sign.JwtPublicKeySignInternal

  def primitive_class(self) -> Type[_jwt_public_key_sign.JwtPublicKeySign]:
    """Returns the class of the primitive this wrapper produces."""
    return _jwt_public_key_sign.JwtPublicKeySign

  def wrap(self,
           pset: core.PrimitiveSet) -> _jwt_public_key_sign.JwtPublicKeySign:
    """Validates pset and returns a JwtPublicKeySign built on top of it.

    Raises:
      core.TinkError: if an entry of pset has an output prefix type other
        than RAW or TINK.
    """
    _validate_primitive_set(pset)
    wrapped_signer = _WrappedJwtPublicKeySign(pset)
    return wrapped_signer
90
91
class _JwtPublicKeyVerifyWrapper(
    core.PrimitiveWrapper[_jwt_public_key_verify.JwtPublicKeyVerifyInternal,
                          _jwt_public_key_verify.JwtPublicKeyVerify]):
  """Wraps JwtPublicKeyVerifyInternal primitive sets as a JwtPublicKeyVerify.

  The resulting primitive tries the entries of the set when verifying.
  """

  def input_primitive_class(
      self) -> Type[_jwt_public_key_verify.JwtPublicKeyVerifyInternal]:
    """Returns the class of the primitives this wrapper accepts."""
    return _jwt_public_key_verify.JwtPublicKeyVerifyInternal

  def primitive_class(self) -> Type[_jwt_public_key_verify.JwtPublicKeyVerify]:
    """Returns the class of the primitive this wrapper produces."""
    return _jwt_public_key_verify.JwtPublicKeyVerify

  def wrap(
      self,
      pset: core.PrimitiveSet) -> _jwt_public_key_verify.JwtPublicKeyVerify:
    """Validates pset and returns a JwtPublicKeyVerify built on top of it.

    Raises:
      core.TinkError: if an entry of pset has an output prefix type other
        than RAW or TINK.
    """
    _validate_primitive_set(pset)
    wrapped_verifier = _WrappedJwtPublicKeyVerify(pset)
    return wrapped_verifier
109
110
def register():
  """Registers the JWT public-key sign and verify wrappers with the Registry.

  The sign wrapper is registered first, then the verify wrapper.
  """
  for wrapper_class in (_JwtPublicKeySignWrapper, _JwtPublicKeyVerifyWrapper):
    core.Registry.register_primitive_wrapper(wrapper_class())
114