# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions and fake primitives for testing."""

import os
from typing import Mapping

from tink.proto import tink_pb2
from tink import aead
from tink import core
from tink import daead
from tink import hybrid
from tink import mac
from tink import prf
from tink import signature as pk_signature

_RELATIVE_TESTDATA_PATH = 'tink_py/testdata'


def tink_py_testdata_path() -> str:
  """Returns the path to the test data directory to be used for testing.

  Candidate locations, in order:
    1. `$TINK_PYTHON_ROOT_PATH/testdata`
    2. `$TEST_SRCDIR/tink_py/testdata`
  Only the first candidate whose environment variable is set is considered;
  an invalid path is reported instead of falling back to the next one.

  Returns:
    The test data directory derived from the first set environment variable.

  Raises:
    FileNotFoundError: If the first set variable yields a non-existent path.
    ValueError: If neither environment variable is set.
  """
  # List of pairs <Env. variable, Path>.
  testdata_paths = []
  if 'TINK_PYTHON_ROOT_PATH' in os.environ:
    testdata_paths.append(('TINK_PYTHON_ROOT_PATH',
                           os.path.join(os.environ['TINK_PYTHON_ROOT_PATH'],
                                        'testdata')))
  if 'TEST_SRCDIR' in os.environ:
    testdata_paths.append(('TEST_SRCDIR',
                           os.path.join(os.environ['TEST_SRCDIR'],
                                        _RELATIVE_TESTDATA_PATH)))
  for env_variable, testdata_path in testdata_paths:
    # Return the first path that is encountered; a set-but-invalid variable
    # is an error rather than a reason to try the next candidate.
    if not os.path.exists(testdata_path):
      raise FileNotFoundError(f'Variable {env_variable} is set but has an '
                              f'invalid path {testdata_path}')
    return testdata_path
  raise ValueError('No path environment variable set among '
                   'TINK_PYTHON_ROOT_PATH, TEST_SRCDIR')


def fake_key(
    value: bytes = b'fakevalue',
    type_url: str = 'fakeurl',
    key_material_type: tink_pb2.KeyData.KeyMaterialType = tink_pb2.KeyData
    .SYMMETRIC,
    key_id: int = 1234,
    status: tink_pb2.KeyStatusType = tink_pb2.ENABLED,
    output_prefix_type: tink_pb2.OutputPrefixType = tink_pb2.TINK
) -> tink_pb2.Keyset.Key:
  """Returns a fake but valid key.

  Args:
    value: Bytes stored in `key_data.value`.
    type_url: String stored in `key_data.type_url`.
    key_material_type: Stored in `key_data.key_material_type`.
    key_id: The key's id.
    status: The key's status.
    output_prefix_type: The key's output prefix type.

  Returns:
    A `tink_pb2.Keyset.Key` populated from the arguments.
  """
  key = tink_pb2.Keyset.Key(
      key_id=key_id,
      status=status,
      output_prefix_type=output_prefix_type)
  key.key_data.type_url = type_url
  key.key_data.value = value
  key.key_data.key_material_type = key_material_type
  return key


def _printable(data: bytes) -> str:
  """Decodes `data` for an error message without ever raising.

  Valid UTF-8 decodes exactly as `bytes.decode()` would. Invalid sequences
  become backslash escapes instead of raising UnicodeDecodeError, so the
  fakes below always get to raise the `core.TinkError` callers expect.
  """
  return data.decode('utf-8', errors='backslashreplace')


def _fake_unwrap(ciphertext: bytes, context: bytes, name: str) -> bytes:
  """Inverts the fake `plaintext|context|name` ciphertext encoding.

  Matches on the known `|context|name` suffix instead of splitting on b'|',
  so plaintexts and contexts that themselves contain b'|' round-trip.

  Args:
    ciphertext: Bytes produced as `plaintext + b'|' + context + b'|' + name`.
    context: The associated data / context info expected in the ciphertext.
    name: The name of the fake primitive expected in the ciphertext.

  Returns:
    The plaintext part of `ciphertext`.

  Raises:
    core.TinkError: If `ciphertext` does not end with `|context|name`.
  """
  suffix = b'|' + context + b'|' + name.encode()
  if not ciphertext.endswith(suffix):
    raise core.TinkError('failed to decrypt ciphertext ' +
                         _printable(ciphertext))
  # `suffix` is always at least 2 bytes long, so this never becomes the
  # degenerate slice `ciphertext[:-0]` (which would be empty).
  return ciphertext[:-len(suffix)]


class FakeMac(mac.Mac):
  """A fake MAC implementation.

  The "MAC" of `data` is `data|name`; nothing is actually authenticated.
  """

  def __init__(self, name: str = 'FakeMac'):
    self._name = name

  def compute_mac(self, data: bytes) -> bytes:
    """Returns `data|name`."""
    return data + b'|' + self._name.encode()

  def verify_mac(self, mac_value: bytes, data: bytes) -> None:
    """Raises core.TinkError unless `mac_value` equals `data|name`."""
    if mac_value != data + b'|' + self._name.encode():
      raise core.TinkError('invalid mac ' + _printable(mac_value))


class FakeAead(aead.Aead):
  """A fake AEAD implementation.

  Ciphertexts have the form `plaintext|associated_data|name`; no actual
  encryption is performed.
  """

  def __init__(self, name: str = 'FakeAead'):
    self._name = name

  def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
    """Returns `plaintext|associated_data|name`."""
    return plaintext + b'|' + associated_data + b'|' + self._name.encode()

  def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
    """Recovers the plaintext; raises core.TinkError on a mismatch."""
    return _fake_unwrap(ciphertext, associated_data, self._name)


class FakeDeterministicAead(daead.DeterministicAead):
  """A fake Deterministic AEAD implementation.

  Ciphertexts have the form `plaintext|associated_data|name`; no actual
  encryption is performed.
  """

  def __init__(self, name: str = 'FakeDeterministicAead'):
    self._name = name

  def encrypt_deterministically(self, plaintext: bytes,
                                associated_data: bytes) -> bytes:
    """Returns `plaintext|associated_data|name`."""
    return plaintext + b'|' + associated_data + b'|' + self._name.encode()

  def decrypt_deterministically(self, ciphertext: bytes,
                                associated_data: bytes) -> bytes:
    """Recovers the plaintext; raises core.TinkError on a mismatch."""
    return _fake_unwrap(ciphertext, associated_data, self._name)


class FakeHybridDecrypt(hybrid.HybridDecrypt):
  """A fake HybridDecrypt implementation.

  Decrypts ciphertexts of the form `plaintext|context_info|name`, as
  produced by a `FakeHybridEncrypt` with the same name.
  """

  def __init__(self, name: str = 'Hybrid'):
    self._name = name

  def decrypt(self, ciphertext: bytes, context_info: bytes) -> bytes:
    """Recovers the plaintext; raises core.TinkError on a mismatch."""
    return _fake_unwrap(ciphertext, context_info, self._name)


class FakeHybridEncrypt(hybrid.HybridEncrypt):
  """A fake HybridEncrypt implementation.

  Ciphertexts have the form `plaintext|context_info|name`; no actual
  encryption is performed.
  """

  def __init__(self, name: str = 'Hybrid'):
    self._name = name

  def encrypt(self, plaintext: bytes, context_info: bytes) -> bytes:
    """Returns `plaintext|context_info|name`."""
    return plaintext + b'|' + context_info + b'|' + self._name.encode()


class FakePublicKeySign(pk_signature.PublicKeySign):
  """A fake PublicKeySign implementation.

  The "signature" of `data` is `data|name`; nothing is actually signed.
  """

  def __init__(self, name: str = 'FakePublicKeySign'):
    self._name = name

  def sign(self, data: bytes) -> bytes:
    """Returns `data|name`."""
    return data + b'|' + self._name.encode()


class FakePublicKeyVerify(pk_signature.PublicKeyVerify):
  """A fake PublicKeyVerify implementation.

  Accepts exactly the signatures `data|name`, i.e. those produced by a
  `FakePublicKeySign` configured with the same name.
  """

  def __init__(self, name: str = 'FakePublicKeyVerify'):
    self._name = name

  def verify(self, signature: bytes, data: bytes) -> None:
    """Raises core.TinkError unless `signature` equals `data|name`."""
    if signature != data + b'|' + self._name.encode():
      raise core.TinkError('invalid signature ' + _printable(signature))


class FakePrf(prf.Prf):
  """A fake Prf implementation."""

  def __init__(self, name: str = 'FakePrf'):
    self._name = name

  def compute(self, input_data: bytes, output_length: int) -> bytes:
    """Returns `input_data|name|` padded with b'*', cut to `output_length`.

    Raises:
      core.TinkError: If `output_length` is greater than 32.
    """
    if output_length > 32:
      raise core.TinkError('invalid output_length')
    # Appending output_length padding bytes guarantees the slice below has
    # enough material to return exactly output_length bytes.
    output = (
        input_data + b'|' + self._name.encode() + b'|' +
        b'*' * output_length)
    return output[:output_length]


class FakePrfSet(prf.PrfSet):
  """A fake PrfSet implementation that contains exactly one Prf.

  The single `FakePrf` is registered under id 0, which is also the primary id.
  """

  def __init__(self, name: str = 'FakePrf'):
    self._prf = FakePrf(name)

  def primary_id(self) -> int:
    """Returns 0, the id of the only Prf in the set."""
    return 0

  def all(self) -> Mapping[int, prf.Prf]:
    """Returns a mapping from id 0 to the single FakePrf."""
    return {0: self._prf}

  def primary(self) -> prf.Prf:
    """Returns the single FakePrf."""
    return self._prf