xref: /aosp_15_r20/external/tink/python/tink/testing/helper.py (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
"""This module implements helper functions and fake primitives for testing."""
16
17import os
18from typing import Mapping
19
20from tink.proto import tink_pb2
21from tink import aead
22from tink import core
23from tink import daead
24from tink import hybrid
25from tink import mac
26from tink import prf
27from tink import signature as pk_signature
28
# Location of the test data relative to the directory named by $TEST_SRCDIR.
_RELATIVE_TESTDATA_PATH = 'tink_py/testdata'


def tink_py_testdata_path() -> str:
  """Returns the path to the test data directory to be used for testing.

  Environment variables are consulted in priority order:
  TINK_PYTHON_ROOT_PATH (joined with 'testdata') first, then TEST_SRCDIR
  (joined with _RELATIVE_TESTDATA_PATH). Only the first variable that is set
  is used; a later variable is NOT tried as a fallback when the first one
  points at a missing directory.

  Returns:
    The test data directory derived from the first set variable.

  Raises:
    FileNotFoundError: The first set variable yields a non-existent path.
    ValueError: None of the variables is set.
  """
  candidates = (
      ('TINK_PYTHON_ROOT_PATH', 'testdata'),
      ('TEST_SRCDIR', _RELATIVE_TESTDATA_PATH),
  )
  for env_variable, relative_path in candidates:
    root = os.environ.get(env_variable)
    if root is None:
      continue  # Variable not set; look at the next candidate.
    testdata_path = os.path.join(root, relative_path)
    if os.path.exists(testdata_path):
      return testdata_path
    raise FileNotFoundError(f'Variable {env_variable} is set but has an '
                            f'invalid path {testdata_path}')
  raise ValueError('No path environment variable set among '
                   'TINK_PYTHON_ROOT_PATH, TEST_SRCDIR')
52
53
def fake_key(
    value: bytes = b'fakevalue',
    type_url: str = 'fakeurl',
    key_material_type: tink_pb2.KeyData.KeyMaterialType = (
        tink_pb2.KeyData.SYMMETRIC),
    key_id: int = 1234,
    status: tink_pb2.KeyStatusType = tink_pb2.ENABLED,
    output_prefix_type: tink_pb2.OutputPrefixType = tink_pb2.TINK
) -> tink_pb2.Keyset.Key:
  """Returns a fake but valid key.

  Args:
    value: Bytes stored as the key data value.
    type_url: Type URL recorded in the key data.
    key_material_type: Key material type recorded in the key data.
    key_id: ID assigned to the key.
    status: Status assigned to the key.
    output_prefix_type: Output prefix type assigned to the key.

  Returns:
    A tink_pb2.Keyset.Key populated from the arguments.
  """
  key_data = tink_pb2.KeyData(
      type_url=type_url,
      value=value,
      key_material_type=key_material_type)
  return tink_pb2.Keyset.Key(
      key_data=key_data,
      key_id=key_id,
      status=status,
      output_prefix_type=output_prefix_type)
72
73
class FakeMac(mac.Mac):
  """A fake MAC implementation.

  The "tag" is the data followed by b'|' and the UTF-8 encoded name, so a tag
  only verifies with a FakeMac of the same name.
  """

  def __init__(self, name: str = 'FakeMac'):
    self._name = name

  def compute_mac(self, data: bytes) -> bytes:
    """Returns data + b'|' + the encoded name."""
    return data + b'|' + self._name.encode()

  def verify_mac(self, mac_value: bytes, data: bytes) -> None:
    """Checks that mac_value equals compute_mac(data).

    Raises:
      core.TinkError: mac_value does not match.
    """
    if mac_value != data + b'|' + self._name.encode():
      # 'backslashreplace' keeps the message unchanged for UTF-8 tags but,
      # unlike a plain decode(), cannot raise UnicodeDecodeError on arbitrary
      # (e.g. random) tag bytes, so callers always get a TinkError.
      raise core.TinkError(
          'invalid mac ' + mac_value.decode('utf-8', errors='backslashreplace'))
86
87
class FakeAead(aead.Aead):
  """A fake AEAD implementation.

  The "ciphertext" is plaintext + b'|' + associated_data + b'|' + name, with
  no actual encryption.
  """

  def __init__(self, name: str = 'FakeAead'):
    self._name = name

  def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
    """Returns plaintext + b'|' + associated_data + b'|' + the encoded name."""
    return plaintext + b'|' + associated_data + b'|' + self._name.encode()

  def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
    """Inverts encrypt() for the same name and associated_data.

    The trailing b'|' + associated_data + b'|' + name is matched exactly
    rather than splitting on b'|'. This lets plaintexts and associated data
    that themselves contain b'|' round-trip correctly.

    Raises:
      core.TinkError: ciphertext does not end with the expected suffix.
    """
    suffix = b'|' + associated_data + b'|' + self._name.encode()
    if not ciphertext.endswith(suffix):
      # 'backslashreplace' avoids UnicodeDecodeError on non-UTF-8 input.
      raise core.TinkError(
          'failed to decrypt ciphertext ' +
          ciphertext.decode('utf-8', errors='backslashreplace'))
    # suffix is never empty, so this slice never degenerates to [:-0].
    return ciphertext[:-len(suffix)]
104
105
class FakeDeterministicAead(daead.DeterministicAead):
  """A fake Deterministic AEAD implementation.

  The "ciphertext" is plaintext + b'|' + associated_data + b'|' + name, with
  no actual encryption.
  """

  def __init__(self, name: str = 'FakeDeterministicAead'):
    self._name = name

  def encrypt_deterministically(self, plaintext: bytes,
                                associated_data: bytes) -> bytes:
    """Returns plaintext + b'|' + associated_data + b'|' + the encoded name."""
    return plaintext + b'|' + associated_data + b'|' + self._name.encode()

  def decrypt_deterministically(self, ciphertext: bytes,
                                associated_data: bytes) -> bytes:
    """Inverts encrypt_deterministically() for the same name and data.

    The trailing b'|' + associated_data + b'|' + name is matched exactly
    rather than splitting on b'|'. This lets plaintexts and associated data
    that themselves contain b'|' round-trip correctly.

    Raises:
      core.TinkError: ciphertext does not end with the expected suffix.
    """
    suffix = b'|' + associated_data + b'|' + self._name.encode()
    if not ciphertext.endswith(suffix):
      # 'backslashreplace' avoids UnicodeDecodeError on non-UTF-8 input.
      raise core.TinkError(
          'failed to decrypt ciphertext ' +
          ciphertext.decode('utf-8', errors='backslashreplace'))
    # suffix is never empty, so this slice never degenerates to [:-0].
    return ciphertext[:-len(suffix)]
125
126
class FakeHybridDecrypt(hybrid.HybridDecrypt):
  """A fake HybridDecrypt implementation.

  Accepts ciphertexts of the form plaintext + b'|' + context_info + b'|' +
  name, as produced by FakeHybridEncrypt with the same name.
  """

  def __init__(self, name: str = 'Hybrid'):
    self._name = name

  def decrypt(self, ciphertext: bytes, context_info: bytes) -> bytes:
    """Strips the b'|' + context_info + b'|' + name suffix from ciphertext.

    The suffix is matched exactly rather than splitting on b'|'. This lets
    plaintexts and context info that themselves contain b'|' round-trip
    correctly.

    Raises:
      core.TinkError: ciphertext does not end with the expected suffix.
    """
    suffix = b'|' + context_info + b'|' + self._name.encode()
    if not ciphertext.endswith(suffix):
      # 'backslashreplace' avoids UnicodeDecodeError on non-UTF-8 input.
      raise core.TinkError(
          'failed to decrypt ciphertext ' +
          ciphertext.decode('utf-8', errors='backslashreplace'))
    # suffix is never empty, so this slice never degenerates to [:-0].
    return ciphertext[:-len(suffix)]
141
142
class FakeHybridEncrypt(hybrid.HybridEncrypt):
  """A fake HybridEncrypt implementation.

  The "ciphertext" is the plaintext, the context info and the UTF-8 encoded
  name, separated by b'|'; nothing is actually encrypted.
  """

  def __init__(self, name: str = 'Hybrid'):
    self._name = name

  def encrypt(self, plaintext: bytes, context_info: bytes) -> bytes:
    """Returns plaintext + b'|' + context_info + b'|' + the encoded name."""
    name_tag = self._name.encode()
    separator = b'|'
    return plaintext + separator + context_info + separator + name_tag
151
152
class FakePublicKeySign(pk_signature.PublicKeySign):
  """A fake PublicKeySign implementation.

  The "signature" is the signed data followed by b'|' and the UTF-8 encoded
  name.
  """

  def __init__(self, name: str = 'FakePublicKeySign'):
    self._name = name

  def sign(self, data: bytes) -> bytes:
    """Returns data + b'|' + the encoded signer name."""
    signer_tag = self._name.encode()
    return data + b'|' + signer_tag
161
162
class FakePublicKeyVerify(pk_signature.PublicKeyVerify):
  """A fake PublicKeyVerify implementation.

  Accepts a signature only if it equals data + b'|' + the UTF-8 encoded name.
  """

  def __init__(self, name: str = 'FakePublicKeyVerify'):
    self._name = name

  def verify(self, signature: bytes, data: bytes) -> None:
    """Checks that signature equals data + b'|' + the encoded name.

    Raises:
      core.TinkError: the signature does not match.
    """
    if signature != data + b'|' + self._name.encode():
      # 'backslashreplace' keeps the message unchanged for UTF-8 signatures
      # but cannot raise UnicodeDecodeError on arbitrary signature bytes.
      raise core.TinkError(
          'invalid signature ' +
          signature.decode('utf-8', errors='backslashreplace'))
172
173
class FakePrf(prf.Prf):
  """A fake Prf implementation.

  The output is input_data + b'|' + name + b'|' followed by b'*' padding,
  truncated to the requested length.
  """

  def __init__(self, name: str = 'FakePrf'):
    self._name = name

  def compute(self, input_data: bytes, output_length: int) -> bytes:
    """Returns exactly output_length bytes derived from input_data.

    Raises:
      core.TinkError: output_length is negative or greater than 32.
    """
    # A negative length would otherwise slice from the end below and return
    # a result whose size differs from what was requested.
    if not 0 <= output_length <= 32:
      raise core.TinkError('invalid output_length')
    # Padding with output_length '*' bytes guarantees the result is exactly
    # output_length bytes long, even for short inputs and names.
    output = (
        input_data + b'|' + self._name.encode() + b'|' + b'*' * output_length)
    return output[:output_length]
187
188
class FakePrfSet(prf.PrfSet):
  """A fake PrfSet holding a single FakePrf under ID 0, which is also primary."""

  def __init__(self, name: str = 'FakePrf'):
    self._prf = FakePrf(name)

  def primary_id(self) -> int:
    """Returns the ID of the only Prf in the set, which is always 0."""
    return 0

  def all(self) -> Mapping[int, prf.Prf]:
    """Returns a new mapping from the single ID (0) to the Prf."""
    only_id = 0
    return {only_id: self._prf}

  def primary(self) -> prf.Prf:
    """Returns the only Prf in the set."""
    return self._prf
203