xref: /aosp_15_r20/external/tink/testing/cross_language/util/_primitives.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Implements tink primitives from gRPC testing_api stubs."""
15
16import datetime
17import io
18import json
19from typing import BinaryIO, Dict, Optional, Mapping, Tuple
20
21import tink
22from tink import aead
23from tink import daead
24from tink import hybrid
25from tink import jwt
26from tink import mac
27from tink import prf
28from tink import signature as tink_signature
29from tink import streaming_aead
30
31from tink.proto import tink_pb2
32from protos import testing_api_pb2
33from protos import testing_api_pb2_grpc
34
35
def key_template(stub: testing_api_pb2_grpc.KeysetStub,
                 template_name: str) -> tink_pb2.KeyTemplate:
  """Fetches a named key template from a testing server.

  Args:
    stub: Keyset service stub used for the GetTemplate call.
    template_name: name of the template to look up.

  Returns:
    The KeyTemplate parsed from the serialized template in the reply.

  Raises:
    tink.TinkError: if the reply carries a non-empty error string.
  """
  template_reply = stub.GetTemplate(
      testing_api_pb2.KeysetTemplateRequest(template_name=template_name))
  if not template_reply.err:
    return tink_pb2.KeyTemplate.FromString(template_reply.key_template)
  raise tink.TinkError(template_reply.err)
43
44
def new_keyset(stub: testing_api_pb2_grpc.KeysetStub,
               template: tink_pb2.KeyTemplate) -> bytes:
  """Asks a testing server to generate a keyset from a key template.

  Args:
    stub: Keyset service stub used for the Generate call.
    template: key template; it is sent to the server in serialized form.

  Returns:
    The serialized keyset returned by the server.

  Raises:
    tink.TinkError: if the reply carries a non-empty error string.
  """
  serialized_template = template.SerializeToString()
  generate_reply = stub.Generate(
      testing_api_pb2.KeysetGenerateRequest(template=serialized_template))
  if not generate_reply.err:
    return generate_reply.keyset
  raise tink.TinkError(generate_reply.err)
53
54
def public_keyset(stub: testing_api_pb2_grpc.KeysetStub,
                  private_keyset: bytes) -> bytes:
  """Obtains the public keyset that corresponds to a private keyset.

  Args:
    stub: Keyset service stub used for the Public call.
    private_keyset: serialized private keyset.

  Returns:
    The serialized public keyset returned by the server.

  Raises:
    tink.TinkError: if the reply carries a non-empty error string.
  """
  public_reply = stub.Public(
      testing_api_pb2.KeysetPublicRequest(private_keyset=private_keyset))
  if not public_reply.err:
    return public_reply.public_keyset
  raise tink.TinkError(public_reply.err)
62
63
def keyset_to_json(
    stub: testing_api_pb2_grpc.KeysetStub,
    keyset: bytes) -> str:
  """Converts a binary keyset to its JSON form via a testing server.

  Args:
    stub: Keyset service stub used for the ToJson call.
    keyset: serialized keyset to convert.

  Returns:
    The JSON keyset string produced by the server.

  Raises:
    tink.TinkError: if the reply carries a non-empty error string.
  """
  json_reply = stub.ToJson(testing_api_pb2.KeysetToJsonRequest(keyset=keyset))
  if not json_reply.err:
    return json_reply.json_keyset
  raise tink.TinkError(json_reply.err)
72
73
def keyset_from_json(
    stub: testing_api_pb2_grpc.KeysetStub,
    json_keyset: str) -> bytes:
  """Converts a JSON keyset back to its binary form via a testing server.

  Args:
    stub: Keyset service stub used for the FromJson call.
    json_keyset: keyset in JSON form.

  Returns:
    The serialized keyset produced by the server.

  Raises:
    tink.TinkError: if the reply carries a non-empty error string.
  """
  binary_reply = stub.FromJson(
      testing_api_pb2.KeysetFromJsonRequest(json_keyset=json_keyset))
  if not binary_reply.err:
    return binary_reply.keyset
  raise tink.TinkError(binary_reply.err)
82
83
def keyset_read_encrypted(stub: testing_api_pb2_grpc.KeysetStub,
                          encrypted_keyset: bytes, master_keyset: bytes,
                          associated_data: Optional[bytes],
                          keyset_reader_type: str) -> bytes:
  """Reads an encrypted keyset.

  Args:
    stub: Keyset service stub used for the ReadEncrypted call.
    encrypted_keyset: the encrypted keyset to read.
    master_keyset: serialized keyset sent as the request's master_keyset.
    associated_data: optional associated data; when None the request's
      associated_data field is left untouched.
    keyset_reader_type: name of a KeysetReaderType enum value.

  Returns:
    The keyset returned by the server.

  Raises:
    ValueError: if keyset_reader_type is not a KeysetReaderType name.
    tink.TinkError: if the reply carries a non-empty error string.
  """
  reader_enum = testing_api_pb2.KeysetReaderType.Value(keyset_reader_type)
  read_request = testing_api_pb2.KeysetReadEncryptedRequest(
      encrypted_keyset=encrypted_keyset,
      master_keyset=master_keyset,
      keyset_reader_type=reader_enum)
  if associated_data is not None:
    read_request.associated_data.value = associated_data
  read_reply = stub.ReadEncrypted(read_request)
  if not read_reply.err:
    return read_reply.keyset
  raise tink.TinkError(read_reply.err)
100
101
def keyset_write_encrypted(stub: testing_api_pb2_grpc.KeysetStub, keyset: bytes,
                           master_keyset: bytes,
                           associated_data: Optional[bytes],
                           keyset_writer_type: str) -> bytes:
  """Writes an encrypted keyset.

  Args:
    stub: Keyset service stub used for the WriteEncrypted call.
    keyset: serialized keyset to encrypt and write.
    master_keyset: serialized keyset sent as the request's master_keyset.
    associated_data: optional associated data; when None the request's
      associated_data field is left untouched.
    keyset_writer_type: name of a KeysetWriterType enum value.

  Returns:
    The encrypted keyset returned by the server.

  Raises:
    ValueError: if keyset_writer_type is not a KeysetWriterType name.
    tink.TinkError: if the reply carries a non-empty error string.
  """
  writer_enum = testing_api_pb2.KeysetWriterType.Value(keyset_writer_type)
  write_request = testing_api_pb2.KeysetWriteEncryptedRequest(
      keyset=keyset,
      master_keyset=master_keyset,
      keyset_writer_type=writer_enum)
  if associated_data is not None:
    write_request.associated_data.value = associated_data
  write_reply = stub.WriteEncrypted(write_request)
  if not write_reply.err:
    return write_reply.encrypted_keyset
  raise tink.TinkError(write_reply.err)
118
119
def jwk_set_to_keyset(stub: testing_api_pb2_grpc.JwtStub,
                      jwk_set: str) -> bytes:
  """Converts a JWK set into a Tink keyset via a testing server.

  Args:
    stub: Jwt service stub used for the FromJwkSet call.
    jwk_set: the JWK set string to convert.

  Returns:
    The keyset returned by the server.

  Raises:
    tink.TinkError: if the reply carries a non-empty error string.
  """
  conversion = stub.FromJwkSet(
      testing_api_pb2.JwtFromJwkSetRequest(jwk_set=jwk_set))
  if not conversion.err:
    return conversion.keyset
  raise tink.TinkError(conversion.err)
127
128
def jwk_set_from_keyset(stub: testing_api_pb2_grpc.JwtStub,
                        keyset: bytes) -> str:
  """Converts a Tink keyset into a JWK set via a testing server.

  Args:
    stub: Jwt service stub used for the ToJwkSet call.
    keyset: serialized keyset to convert.

  Returns:
    The JWK set string returned by the server.

  Raises:
    tink.TinkError: if the reply carries a non-empty error string.
  """
  conversion = stub.ToJwkSet(testing_api_pb2.JwtToJwkSetRequest(keyset=keyset))
  if not conversion.err:
    return conversion.jwk_set
  raise tink.TinkError(conversion.err)
136
137
class Aead(aead.Aead):
  """Aead primitive whose operations are performed by a testing server.

  Every call sends the serialized keyset together with its annotations to the
  AEAD service and turns an error reply into a tink.TinkError.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.AeadStub,
               keyset: bytes, annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection data and asks the server to create the AEAD.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: AEAD service stub.
      keyset: serialized keyset.
      annotations: optional annotations attached to every request.

    Raises:
      tink.TinkError: if the server's Create call reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._annotations = annotations
    create_reply = self._stub.Create(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if create_reply.err:
      raise tink.TinkError(create_reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the keyset and its annotations into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset, annotations=self._annotations)

  def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
    """Encrypts plaintext on the server and returns the ciphertext."""
    reply = self._stub.Encrypt(
        testing_api_pb2.AeadEncryptRequest(
            annotated_keyset=self._annotated_keyset(),
            plaintext=plaintext,
            associated_data=associated_data))
    if not reply.err:
      return reply.ciphertext
    raise tink.TinkError(reply.err)

  def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
    """Decrypts ciphertext on the server and returns the plaintext."""
    reply = self._stub.Decrypt(
        testing_api_pb2.AeadDecryptRequest(
            annotated_keyset=self._annotated_keyset(),
            ciphertext=ciphertext,
            associated_data=associated_data))
    if not reply.err:
      return reply.plaintext
    raise tink.TinkError(reply.err)
176
177
class DeterministicAead(daead.DeterministicAead):
  """DeterministicAead primitive whose operations run on a testing server.

  Requests carry the serialized keyset and its annotations; error replies
  are raised as tink.TinkError.
  """

  def __init__(self, lang: str,
               stub: testing_api_pb2_grpc.DeterministicAeadStub, keyset: bytes,
               annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection data and asks the server to create the DAEAD.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: deterministic AEAD service stub.
      keyset: serialized keyset.
      annotations: optional annotations attached to every request.

    Raises:
      tink.TinkError: if the server's Create call reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._annotations = annotations
    create_reply = self._stub.Create(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if create_reply.err:
      raise tink.TinkError(create_reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the keyset and its annotations into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset, annotations=self._annotations)

  def encrypt_deterministically(self, plaintext: bytes,
                                associated_data: bytes) -> bytes:
    """Encrypts."""
    reply = self._stub.EncryptDeterministically(
        testing_api_pb2.DeterministicAeadEncryptRequest(
            annotated_keyset=self._annotated_keyset(),
            plaintext=plaintext,
            associated_data=associated_data))
    if not reply.err:
      return reply.ciphertext
    raise tink.TinkError(reply.err)

  def decrypt_deterministically(self, ciphertext: bytes,
                                associated_data: bytes) -> bytes:
    """Decrypts."""
    reply = self._stub.DecryptDeterministically(
        testing_api_pb2.DeterministicAeadDecryptRequest(
            annotated_keyset=self._annotated_keyset(),
            ciphertext=ciphertext,
            associated_data=associated_data))
    if not reply.err:
      return reply.plaintext
    raise tink.TinkError(reply.err)
220
221
class StreamingAead(streaming_aead.StreamingAead):
  """StreamingAead primitive whose operations run on a testing server.

  The input stream is read completely and sent in a single request. The
  server's result is returned wrapped in an in-memory io.BytesIO. No keyset
  annotations are sent.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.StreamingAeadStub,
               keyset: bytes) -> None:
    """Stores the connection data and asks the server to create the primitive.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: streaming AEAD service stub.
      keyset: serialized keyset.

    Raises:
      tink.TinkError: if the server's Create call reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    create_reply = self._stub.Create(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if create_reply.err:
      raise tink.TinkError(create_reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the keyset (without annotations) into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(serialized_keyset=self._keyset)

  def new_encrypting_stream(self, plaintext: BinaryIO,
                            associated_data: bytes) -> BinaryIO:
    """Reads all of plaintext and returns the ciphertext as a BytesIO."""
    reply = self._stub.Encrypt(
        testing_api_pb2.StreamingAeadEncryptRequest(
            annotated_keyset=self._annotated_keyset(),
            plaintext=plaintext.read(),
            associated_data=associated_data))
    if not reply.err:
      return io.BytesIO(reply.ciphertext)
    raise tink.TinkError(reply.err)

  def new_decrypting_stream(self, ciphertext: BinaryIO,
                            associated_data: bytes) -> BinaryIO:
    """Reads all of ciphertext and returns the plaintext as a BytesIO."""
    reply = self._stub.Decrypt(
        testing_api_pb2.StreamingAeadDecryptRequest(
            annotated_keyset=self._annotated_keyset(),
            ciphertext=ciphertext.read(),
            associated_data=associated_data))
    if not reply.err:
      return io.BytesIO(reply.plaintext)
    raise tink.TinkError(reply.err)
260
261
class Mac(mac.Mac):
  """Mac primitive whose operations are performed by a testing server.

  Requests carry the serialized keyset and its annotations; error replies
  are raised as tink.TinkError.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.MacStub,
               keyset: bytes, annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection data and asks the server to create the MAC.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: MAC service stub.
      keyset: serialized keyset.
      annotations: optional annotations attached to every request.

    Raises:
      tink.TinkError: if the server's Create call reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._annotations = annotations
    create_reply = self._stub.Create(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if create_reply.err:
      raise tink.TinkError(create_reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the keyset and its annotations into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset, annotations=self._annotations)

  def compute_mac(self, data: bytes) -> bytes:
    """Returns the MAC tag the server computes over data."""
    reply = self._stub.ComputeMac(
        testing_api_pb2.ComputeMacRequest(
            annotated_keyset=self._annotated_keyset(), data=data))
    if not reply.err:
      return reply.mac_value
    raise tink.TinkError(reply.err)

  def verify_mac(self, mac_value: bytes, data: bytes) -> None:
    """Raises tink.TinkError unless the server accepts mac_value for data."""
    reply = self._stub.VerifyMac(
        testing_api_pb2.VerifyMacRequest(
            annotated_keyset=self._annotated_keyset(),
            mac_value=mac_value,
            data=data))
    if reply.err:
      raise tink.TinkError(reply.err)
297
298
class HybridEncrypt(hybrid.HybridEncrypt):
  """HybridEncrypt primitive whose operations run on a testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.HybridStub,
               public_handle: bytes, annotations: Optional[Dict[str,
                                                                str]]) -> None:
    """Stores the connection data and asks the server to create the primitive.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: hybrid encryption service stub.
      public_handle: serialized public keyset.
      annotations: optional annotations attached to every request.

    Raises:
      tink.TinkError: if CreateHybridEncrypt reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._public_handle = public_handle
    self._annotations = annotations
    create_reply = self._stub.CreateHybridEncrypt(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if create_reply.err:
      raise tink.TinkError(create_reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the public keyset and its annotations into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._public_handle, annotations=self._annotations)

  def encrypt(self, plaintext: bytes, context_info: bytes) -> bytes:
    """Encrypts plaintext bound to context_info and returns the ciphertext."""
    reply = self._stub.Encrypt(
        testing_api_pb2.HybridEncryptRequest(
            public_annotated_keyset=self._annotated_keyset(),
            plaintext=plaintext,
            context_info=context_info))
    if not reply.err:
      return reply.ciphertext
    raise tink.TinkError(reply.err)
328
329
class HybridDecrypt(hybrid.HybridDecrypt):
  """HybridDecrypt primitive whose operations run on a testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.HybridStub,
               private_handle: bytes, annotations: Optional[Dict[str,
                                                                 str]]) -> None:
    """Stores the connection data and asks the server to create the primitive.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: hybrid encryption service stub.
      private_handle: serialized private keyset.
      annotations: optional annotations attached to every request.

    Raises:
      tink.TinkError: if CreateHybridDecrypt reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._private_handle = private_handle
    self._annotations = annotations
    create_reply = self._stub.CreateHybridDecrypt(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if create_reply.err:
      raise tink.TinkError(create_reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the private keyset and its annotations into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._private_handle, annotations=self._annotations)

  def decrypt(self, ciphertext: bytes, context_info: bytes) -> bytes:
    """Decrypts ciphertext bound to context_info and returns the plaintext."""
    reply = self._stub.Decrypt(
        testing_api_pb2.HybridDecryptRequest(
            private_annotated_keyset=self._annotated_keyset(),
            ciphertext=ciphertext,
            context_info=context_info))
    if not reply.err:
      return reply.plaintext
    raise tink.TinkError(reply.err)
359
360
class PublicKeySign(tink_signature.PublicKeySign):
  """PublicKeySign primitive whose operations run on a testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.SignatureStub,
               private_handle: bytes, annotations: Optional[Dict[str,
                                                                 str]]) -> None:
    """Stores the connection data and asks the server to create the signer.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: signature service stub.
      private_handle: serialized private keyset.
      annotations: optional annotations attached to every request.

    Raises:
      tink.TinkError: if CreatePublicKeySign reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._private_handle = private_handle
    self._annotations = annotations
    create_reply = self._stub.CreatePublicKeySign(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if create_reply.err:
      raise tink.TinkError(create_reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the private keyset and its annotations into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._private_handle, annotations=self._annotations)

  def sign(self, data: bytes) -> bytes:
    """Returns the signature the server computes over data."""
    reply = self._stub.Sign(
        testing_api_pb2.SignatureSignRequest(
            private_annotated_keyset=self._annotated_keyset(), data=data))
    if not reply.err:
      return reply.signature
    raise tink.TinkError(reply.err)
389
390
class PublicKeyVerify(tink_signature.PublicKeyVerify):
  """PublicKeyVerify primitive whose operations run on a testing server."""

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.SignatureStub,
               public_handle: bytes, annotations: Optional[Dict[str,
                                                                str]]) -> None:
    """Stores the connection data and asks the server to create the verifier.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: signature service stub.
      public_handle: serialized public keyset.
      annotations: optional annotations attached to every request.

    Raises:
      tink.TinkError: if CreatePublicKeyVerify reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._public_handle = public_handle
    self._annotations = annotations
    create_reply = self._stub.CreatePublicKeyVerify(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if create_reply.err:
      raise tink.TinkError(create_reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the public keyset and its annotations into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._public_handle, annotations=self._annotations)

  def verify(self, signature: bytes, data: bytes) -> None:  # pytype: disable=signature-mismatch  # overriding-return-type-checks
    """Raises tink.TinkError unless the server accepts signature for data."""
    reply = self._stub.Verify(
        testing_api_pb2.SignatureVerifyRequest(
            public_annotated_keyset=self._annotated_keyset(),
            signature=signature,
            data=data))
    if reply.err:
      raise tink.TinkError(reply.err)
419
420
class _Prf(prf.Prf):
  """One PRF of a server-side PrfSet, selected by key id.

  Unlike the other wrappers, construction makes no server call.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.PrfSetStub,
               keyset: bytes, key_id: int,
               annotations: Optional[Dict[str, str]]) -> None:
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._key_id = key_id
    self._annotations = annotations

  def compute(self, input_data: bytes, output_length: int) -> bytes:
    """Computes output_length bytes of this key's PRF over input_data.

    Raises:
      tink.TinkError: if the reply carries a non-empty error string.
    """
    keyset_message = testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset, annotations=self._annotations)
    reply = self._stub.Compute(
        testing_api_pb2.PrfSetComputeRequest(
            annotated_keyset=keyset_message,
            key_id=self._key_id,
            input_data=input_data,
            output_length=output_length))
    if not reply.err:
      return reply.output
    raise tink.TinkError(reply.err)
444
445
class PrfSet(prf.PrfSet):
  """PrfSet primitive backed by a testing server's PrfSet service.

  The key ids are fetched lazily, on the first call to primary_id(), all()
  or primary(), and then cached for the lifetime of the object.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.PrfSetStub,
               keyset: bytes, annotations: Optional[Dict[str, str]]) -> None:
    """Stores the connection data and asks the server to create the PrfSet.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: PrfSet service stub.
      keyset: serialized keyset.
      annotations: optional annotations attached to every request.

    Raises:
      tink.TinkError: if the server's Create call reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    self._key_ids_initialized = False
    self._primary_key_id = None
    self._prfs = None
    self._annotations = annotations
    create_reply = self._stub.Create(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if create_reply.err:
      raise tink.TinkError(create_reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the keyset and its annotations into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(
        serialized_keyset=self._keyset, annotations=self._annotations)

  def _initialize_key_ids(self) -> None:
    """Fetches the key ids from the server once and builds one _Prf per id."""
    if self._key_ids_initialized:
      return
    ids_reply = self._stub.KeyIds(
        testing_api_pb2.PrfSetKeyIdsRequest(
            annotated_keyset=self._annotated_keyset()))
    if ids_reply.err:
      raise tink.TinkError(ids_reply.err)
    self._primary_key_id = ids_reply.output.primary_key_id
    self._prfs = {
        key_id: _Prf(self.lang, self._stub, self._keyset, key_id,
                     self._annotations)
        for key_id in ids_reply.output.key_id
    }
    self._key_ids_initialized = True

  def primary_id(self) -> int:
    self._initialize_key_ids()
    return self._primary_key_id

  def all(self) -> Mapping[int, prf.Prf]:
    # Hand out a copy so callers cannot mutate the cached mapping.
    self._initialize_key_ids()
    return self._prfs.copy()

  def primary(self) -> prf.Prf:
    self._initialize_key_ids()
    return self._prfs[self._primary_key_id]
491
492
def split_datetime(dt: datetime.datetime) -> Tuple[int, int]:
  """Splits a datetime into (seconds, nanos) since the Unix epoch.

  Uses exact integer timedelta arithmetic instead of the float returned by
  datetime.timestamp():
  - microseconds are preserved exactly (no float round-off such as
    123456 us -> 123456001 ns);
  - seconds are floored rather than truncated toward zero, so nanos always
    lies in [0, 999999999], including for datetimes before 1970.

  Args:
    dt: the datetime to split. A naive datetime is interpreted as local time,
      as datetime.timestamp() does.

  Returns:
    (seconds, nanos) where seconds may be negative and 0 <= nanos < 10**9.
  """
  epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
  # astimezone() treats a naive datetime as local time, matching timestamp().
  delta = dt.astimezone(datetime.timezone.utc) - epoch
  # timedelta normalizes to 0 <= seconds < 86400 and 0 <= microseconds < 10**6,
  # with only .days carrying a sign, which gives floor semantics here.
  seconds = delta.days * 86400 + delta.seconds
  nanos = delta.microseconds * 1000
  return (seconds, nanos)
498
499
def to_datetime(seconds: int, nanos: int) -> datetime.datetime:
  """Builds a UTC datetime from (seconds, nanos) since the Unix epoch.

  The value is computed as epoch + timedelta rather than through a float
  timestamp. The float sum `seconds + nanos / 1e9` loses sub-second precision
  for large seconds; for example, (253402300799, 999999000) would round up
  into year 10000 and raise. Nanoseconds are rounded to the microsecond
  resolution of datetime.

  Args:
    seconds: whole seconds since the epoch (may be negative).
    nanos: nanosecond part.

  Returns:
    A timezone-aware datetime in UTC.
  """
  epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
  return epoch + datetime.timedelta(seconds=seconds, microseconds=nanos / 1000)
503
504
def raw_jwt_to_proto(raw_jwt: jwt.RawJwt) -> testing_api_pb2.JwtToken:
  """Converts a jwt.RawJwt into a testing_api JwtToken proto.

  Only claims present on raw_jwt are written. Custom claim values are mapped
  as follows:
  - None -> null_value
  - bool -> bool_value
  - int/float -> number_value
  - str -> string_value
  - dict/list -> json_object_value / json_array_value, via json.dumps

  Custom claims of any other type are skipped.

  Args:
    raw_jwt: the token to convert.

  Returns:
    A new JwtToken proto holding raw_jwt's claims.
  """
  raw_token = testing_api_pb2.JwtToken()
  if raw_jwt.has_type_header():
    raw_token.type_header.value = raw_jwt.type_header()
  if raw_jwt.has_issuer():
    raw_token.issuer.value = raw_jwt.issuer()
  if raw_jwt.has_subject():
    raw_token.subject.value = raw_jwt.subject()
  if raw_jwt.has_audiences():
    raw_token.audiences.extend(raw_jwt.audiences())
  if raw_jwt.has_jwt_id():
    raw_token.jwt_id.value = raw_jwt.jwt_id()
  if raw_jwt.has_expiration():
    seconds, nanos = split_datetime(raw_jwt.expiration())
    raw_token.expiration.seconds = seconds
    raw_token.expiration.nanos = nanos
  if raw_jwt.has_not_before():
    seconds, nanos = split_datetime(raw_jwt.not_before())
    raw_token.not_before.seconds = seconds
    raw_token.not_before.nanos = nanos
  if raw_jwt.has_issued_at():
    seconds, nanos = split_datetime(raw_jwt.issued_at())
    raw_token.issued_at.seconds = seconds
    raw_token.issued_at.nanos = nanos
  for name in raw_jwt.custom_claim_names():
    value = raw_jwt.custom_claim(name)
    # The map entry is only touched inside a matching branch, so claims of an
    # unsupported type do not create an empty entry.
    if value is None:
      raw_token.custom_claims[name].null_value = testing_api_pb2.NULL_VALUE
    elif isinstance(value, bool):
      # Must be checked before int/float: bool is a subclass of int, so
      # isinstance(True, int) is True and a boolean would otherwise also be
      # written as number_value.
      raw_token.custom_claims[name].bool_value = value
    elif isinstance(value, (int, float)):
      raw_token.custom_claims[name].number_value = value
    elif isinstance(value, str):
      raw_token.custom_claims[name].string_value = value
    elif isinstance(value, dict):
      raw_token.custom_claims[name].json_object_value = json.dumps(value)
    elif isinstance(value, list):
      raw_token.custom_claims[name].json_array_value = json.dumps(value)
  return raw_token
545
546
def proto_to_verified_jwt(
    token: testing_api_pb2.JwtToken) -> jwt.VerifiedJwt:
  """Rebuilds a jwt.VerifiedJwt from a testing_api JwtToken proto.

  Unset optional string claims and timestamps become None. A missing
  expiration also sets without_expiration. Custom claims take the value of
  each set field, checked in this order (a later set field wins):
  - null_value
  - number_value
  - string_value
  - bool_value
  - json_object_value / json_array_value, parsed with json.loads

  Args:
    token: the JwtToken proto to convert.

  Returns:
    A VerifiedJwt wrapping a RawJwt built from the proto's claims.
  """

  def wrapped_value(field: str) -> Optional[str]:
    """Returns token.<field>.value, or None when the field is unset."""
    if token.HasField(field):
      return getattr(token, field).value
    return None

  def timestamp(field: str) -> Optional[datetime.datetime]:
    """Returns token.<field> as a UTC datetime, or None when unset."""
    if token.HasField(field):
      stamp = getattr(token, field)
      return to_datetime(stamp.seconds, stamp.nanos)
    return None

  type_header = wrapped_value('type_header')
  issuer = wrapped_value('issuer')
  subject = wrapped_value('subject')
  jwt_id = wrapped_value('jwt_id')
  audiences = list(token.audiences) if token.audiences else None
  expiration = timestamp('expiration')
  not_before = timestamp('not_before')
  issued_at = timestamp('issued_at')

  claim_decoders = (
      ('null_value', lambda claim: None),
      ('number_value', lambda claim: claim.number_value),
      ('string_value', lambda claim: claim.string_value),
      ('bool_value', lambda claim: claim.bool_value),
      ('json_object_value', lambda claim: json.loads(claim.json_object_value)),
      ('json_array_value', lambda claim: json.loads(claim.json_array_value)),
  )
  custom_claims = {}
  for name, claim in token.custom_claims.items():
    for field, decode in claim_decoders:
      if claim.HasField(field):
        custom_claims[name] = decode(claim)

  raw_jwt = jwt.new_raw_jwt(
      type_header=type_header,
      issuer=issuer,
      subject=subject,
      audiences=audiences,
      jwt_id=jwt_id,
      expiration=expiration,
      without_expiration=expiration is None,
      not_before=not_before,
      issued_at=issued_at,
      custom_claims=custom_claims)
  return jwt.VerifiedJwt._create(raw_jwt)  # pylint: disable=protected-access
604
605
def jwt_validator_to_proto(
    validator: jwt.JwtValidator) -> testing_api_pb2.JwtValidator:
  """Converts a jwt.JwtValidator into a testing_api JwtValidator proto.

  Args:
    validator: the validator to convert.

  Returns:
    A new JwtValidator proto. The expected_* fields and `now` are set only
    when the validator has them. The clock skew is sent in whole seconds;
    any sub-second part is dropped.
  """
  proto_validator = testing_api_pb2.JwtValidator()
  if validator.has_expected_type_header():
    proto_validator.expected_type_header.value = (
        validator.expected_type_header())
  if validator.has_expected_issuer():
    proto_validator.expected_issuer.value = validator.expected_issuer()
  if validator.has_expected_audience():
    proto_validator.expected_audience.value = validator.expected_audience()
  proto_validator.ignore_type_header = validator.ignore_type_header()
  proto_validator.ignore_issuer = validator.ignore_issuer()
  proto_validator.ignore_audience = validator.ignore_audiences()
  proto_validator.allow_missing_expiration = (
      validator.allow_missing_expiration())
  proto_validator.expect_issued_in_the_past = (
      validator.expect_issued_in_the_past())
  # timedelta.seconds is only the within-a-day component (it ignores .days),
  # so floor-divide by one second to send the whole skew in seconds.
  proto_validator.clock_skew.seconds = (
      validator.clock_skew() // datetime.timedelta(seconds=1))
  if validator.has_fixed_now():
    seconds, nanos = split_datetime(validator.fixed_now())
    proto_validator.now.seconds = seconds
    proto_validator.now.nanos = nanos
  return proto_validator
630
631
class JwtMac():
  """JWT MAC primitive whose operations are performed by a testing server.

  Raw JWTs are converted with raw_jwt_to_proto before being sent. Verified
  tokens in the server's reply are rebuilt with proto_to_verified_jwt. No
  keyset annotations are sent.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.JwtStub,
               keyset: bytes) -> None:
    """Stores the connection data and asks the server to create the JWT MAC.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: Jwt service stub.
      keyset: serialized keyset.

    Raises:
      tink.TinkError: if CreateJwtMac reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    creation_response = self._stub.CreateJwtMac(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if creation_response.err:
      raise tink.TinkError(creation_response.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the keyset (without annotations) into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(serialized_keyset=self._keyset)

  def compute_mac_and_encode(self, raw_jwt: jwt.RawJwt) -> str:
    """MACs raw_jwt on the server and returns it in compact serialization.

    Args:
      raw_jwt: the token to MAC and encode.

    Returns:
      The signed compact JWT returned by the server.

    Raises:
      tink.TinkError: if the reply carries a non-empty error string.
    """
    request = testing_api_pb2.JwtSignRequest(
        annotated_keyset=self._annotated_keyset(),
        raw_jwt=raw_jwt_to_proto(raw_jwt))
    response = self._stub.ComputeMacAndEncode(request)
    if response.err:
      raise tink.TinkError(response.err)
    return response.signed_compact_jwt

  def verify_mac_and_decode(self, signed_compact_jwt: str,
                            validator: jwt.JwtValidator) -> jwt.VerifiedJwt:
    """Verifies and decodes a JWT in compact serialization using a MAC.

    Args:
      signed_compact_jwt: the MACed JWT in compact serialization form.
      validator: validator used to validate the JWT.

    Returns:
      The verified JWT rebuilt from the server's reply.

    Raises:
      tink.TinkError: if verification or validation fails.
    """
    request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=self._annotated_keyset(),
        validator=jwt_validator_to_proto(validator),
        signed_compact_jwt=signed_compact_jwt)
    response = self._stub.VerifyMacAndDecode(request)
    if response.err:
      raise tink.TinkError(response.err)
    return proto_to_verified_jwt(response.verified_jwt)
679
680
class JwtPublicKeySign():
  """JWT signer whose signing is performed by a testing server.

  No keyset annotations are sent.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.JwtStub,
               keyset: bytes) -> None:
    """Stores the connection data and asks the server to create the signer.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: Jwt service stub.
      keyset: serialized private keyset.

    Raises:
      tink.TinkError: if CreateJwtPublicKeySign reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    create_reply = self._stub.CreateJwtPublicKeySign(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if create_reply.err:
      raise tink.TinkError(create_reply.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the keyset (without annotations) into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(serialized_keyset=self._keyset)

  def sign_and_encode(self, raw_jwt: jwt.RawJwt) -> str:
    """Signs raw_jwt on the server and returns the compact serialization."""
    token_proto = raw_jwt_to_proto(raw_jwt)
    sign_reply = self._stub.PublicKeySignAndEncode(
        testing_api_pb2.JwtSignRequest(
            annotated_keyset=self._annotated_keyset(), raw_jwt=token_proto))
    if not sign_reply.err:
      return sign_reply.signed_compact_jwt
    raise tink.TinkError(sign_reply.err)
705
706
class JwtPublicKeyVerify():
  """JWT verifier whose verification is performed by a testing server.

  Verified tokens in the server's reply are rebuilt with
  proto_to_verified_jwt. No keyset annotations are sent.
  """

  def __init__(self, lang: str, stub: testing_api_pb2_grpc.JwtStub,
               keyset: bytes) -> None:
    """Stores the connection data and asks the server to create the verifier.

    Args:
      lang: language tag, kept in the public `lang` attribute.
      stub: Jwt service stub.
      keyset: serialized public keyset.

    Raises:
      tink.TinkError: if CreateJwtPublicKeyVerify reports an error.
    """
    self.lang = lang
    self._stub = stub
    self._keyset = keyset
    creation_response = self._stub.CreateJwtPublicKeyVerify(
        testing_api_pb2.CreationRequest(
            annotated_keyset=self._annotated_keyset()))
    if creation_response.err:
      raise tink.TinkError(creation_response.err)

  def _annotated_keyset(self) -> testing_api_pb2.AnnotatedKeyset:
    """Wraps the keyset (without annotations) into an AnnotatedKeyset."""
    return testing_api_pb2.AnnotatedKeyset(serialized_keyset=self._keyset)

  def verify_and_decode(self, signed_compact_jwt: str,
                        validator: jwt.JwtValidator) -> jwt.VerifiedJwt:
    """Verifies and decodes a compact JWT using a digital signature.

    Args:
      signed_compact_jwt: the signed JWT in compact serialization form.
      validator: validator used to validate the JWT.

    Returns:
      The verified JWT rebuilt from the server's reply.

    Raises:
      tink.TinkError: if verification or validation fails.
    """
    request = testing_api_pb2.JwtVerifyRequest(
        annotated_keyset=self._annotated_keyset(),
        validator=jwt_validator_to_proto(validator),
        signed_compact_jwt=signed_compact_jwt)
    response = self._stub.PublicKeyVerifyAndDecode(request)
    if response.err:
      raise tink.TinkError(response.err)
    return proto_to_verified_jwt(response.verified_jwt)
744