1# Copyright 2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import binascii
16import io
17import random
18
19from absl.testing import absltest
20from absl.testing import parameterized
21import tink
22from tink import streaming_aead
23
24from tink.proto import aes_gcm_hkdf_streaming_pb2
25from tink.proto import common_pb2
26from tink.proto import tink_pb2
27import tink_config
28from util import testing_servers
29
30
def setUpModule():
  """Registers Streaming AEAD and starts the testing servers for this module."""
  streaming_aead.register()
  server_test_name = 'aes_gcm_hkdf_streaming_key_test'
  testing_servers.start(server_test_name)
34
35
def tearDownModule():
  """Stops the testing servers that setUpModule started."""
  testing_servers.stop()
38
39
def to_keyset(
    key: aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey,
) -> tink_pb2.Keyset:
  """Wraps an AesGcmHkdfStreamingKey into a single-key keyset.

  The key is serialized into a SYMMETRIC KeyData entry. The entry is marked
  ENABLED, uses the RAW output prefix, and is the primary key (ID 1234).

  Args:
    key: the AesGcmHkdfStreamingKey proto to embed.

  Returns:
    A Keyset containing exactly one key, which is also its primary key.
  """
  key_id = 1234
  key_data = tink_pb2.KeyData(
      type_url='type.googleapis.com/google.crypto.tink.AesGcmHkdfStreamingKey',
      value=key.SerializeToString(),
      key_material_type='SYMMETRIC',
  )
  keyset_key = tink_pb2.Keyset.Key(
      key_data=key_data,
      output_prefix_type=tink_pb2.OutputPrefixType.RAW,
      status=tink_pb2.KeyStatusType.ENABLED,
      key_id=key_id,
  )
  return tink_pb2.Keyset(primary_key_id=key_id, key=[keyset_key])
59
60
def simple_valid_key() -> (
    aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey
):
  """Returns a fresh baseline AesGcmHkdfStreamingKey used as a test template.

  The key has version 0 and a 16-byte key value. Its parameters are 512-byte
  ciphertext segments, a 16-byte derived key size and SHA256 for HKDF. Callers
  mutate the returned object to build the variants they need.
  """
  params = aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingParams(
      ciphertext_segment_size=512,
      derived_key_size=16,
      hkdf_hash_type=common_pb2.HashType.SHA256,
  )
  return aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey(
      version=0, params=params, key_value=b'0123456789abcdef'
  )
74
75
def lang_and_valid_keys_create_and_encrypt():
  """Returns (lang, key) pairs for valid keys that must encrypt and decrypt.

  Every key variant below is paired with every language that supports
  AesGcmHkdfStreamingKey. The order is all languages for the first variant,
  then all languages for the second variant, and so on.
  """
  langs = tink_config.supported_languages_for_key_type('AesGcmHkdfStreamingKey')
  valid_keys = []

  # Baseline key.
  valid_keys.append(simple_valid_key())

  # 32-byte derived key, with a 32-byte input key to match.
  key_32 = simple_valid_key()
  assert key_32.params.derived_key_size == 16
  key_32.params.derived_key_size = 32
  key_32.key_value = b'0123456789abcdef0123456789abcdef'
  valid_keys.append(key_32)

  # HKDF hash types that are expected to be accepted.
  for hash_type in (
      common_pb2.HashType.SHA1,
      common_pb2.HashType.SHA256,
      common_pb2.HashType.SHA512,
  ):
    hash_key = simple_valid_key()
    hash_key.params.hkdf_hash_type = hash_type
    valid_keys.append(hash_key)

  # Smallest accepted ciphertext_segment_size: derived_key_size + 25
  # (derived_key_size + 24 is in the invalid list).
  min_segment_key = simple_valid_key()
  min_segment_key.params.ciphertext_segment_size = (
      min_segment_key.params.derived_key_size + 25
  )
  valid_keys.append(min_segment_key)

  return [(lang, key) for key in valid_keys for lang in langs]
114
115
def lang_and_valid_keys_create_only():
  """Returns (lang, key) pairs for valid keys that only need to create a primitive.

  Starts from lang_and_valid_keys_create_and_encrypt() and appends further
  valid variants. These are not used for encryption, presumably because of the
  Java crash noted in the TODO below.
  """
  pairs = lang_and_valid_keys_create_and_encrypt()
  langs = tink_config.supported_languages_for_key_type('AesGcmHkdfStreamingKey')

  # TODO(b/268193523): Java crashes with ciphertext_segment_size = 2**31 - 1
  max_segment_key = simple_valid_key()
  max_segment_key.params.ciphertext_segment_size = 2**31 - 1
  pairs.extend((lang, max_segment_key) for lang in langs)

  return pairs
127
128
def lang_and_invalid_keys():
  """Returns (lang, key) pairs for keys that creating a primitive must reject.

  Every invalid key variant below is paired with every language that supports
  AesGcmHkdfStreamingKey. The order is all languages for the first variant,
  then all languages for the second variant, and so on.
  """
  langs = tink_config.supported_languages_for_key_type('AesGcmHkdfStreamingKey')
  invalid_keys = []

  # A derived_key_size of 24 is rejected (the valid variants use 16 and 32).
  bad_derived_size = simple_valid_key()
  bad_derived_size.params.derived_key_size = 24
  invalid_keys.append(bad_derived_size)

  # HKDF hash types that are rejected.
  for hash_type in (common_pb2.HashType.SHA224, common_pb2.HashType.SHA384):
    bad_hash = simple_valid_key()
    bad_hash.params.hkdf_hash_type = hash_type
    invalid_keys.append(bad_hash)

  # Requirement: len(InitialKeyMaterial) >= DerivedKeySize (here 16 < 32).
  short_ikm = simple_valid_key()
  short_ikm.key_value = b'0123456789abcdef'
  short_ikm.params.derived_key_size = 32
  invalid_keys.append(short_ikm)

  # HKDF hash type left as UNKNOWN_HASH.
  unknown_hash = simple_valid_key()
  unknown_hash.params.hkdf_hash_type = common_pb2.HashType.UNKNOWN_HASH
  invalid_keys.append(unknown_hash)

  # ciphertext_segment_size one below the minimum (derived_key_size + 25).
  small_segment = simple_valid_key()
  small_segment.params.ciphertext_segment_size = (
      small_segment.params.derived_key_size + 24
  )
  invalid_keys.append(small_segment)

  # ciphertext_segment_size one above the largest valid value (2**31 - 1).
  huge_segment = simple_valid_key()
  huge_segment.params.ciphertext_segment_size = 2**31
  invalid_keys.append(huge_segment)

  return [(lang, key) for key in invalid_keys for lang in langs]
173
174
class AesGcmHkdfStreamingKeyTest(parameterized.TestCase):
  """Tests specific for keys of type AesGcmHkdfStreamingKey.

  See https://developers.google.com/tink/streaming-aead/aes_gcm_hkdf_streaming
  for the documentation.
  """

  @parameterized.parameters(lang_and_valid_keys_create_only())
  def test_create_streaming_aead(
      self, lang: str, key: aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey
  ):
    """Checks that `lang` can create a StreamingAead from a valid key."""
    keyset = to_keyset(key)
    testing_servers.remote_primitive(
        lang, keyset.SerializeToString(), streaming_aead.StreamingAead
    )

  @parameterized.parameters(lang_and_valid_keys_create_and_encrypt())
  def test_create_streaming_aead_encrypt_decrypt(
      self, lang: str, key: aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey
  ):
    """Checks that a valid key encrypts and then decrypts a plaintext in `lang`."""
    keyset = to_keyset(key)
    saead = testing_servers.remote_primitive(
        lang, keyset.SerializeToString(), streaming_aead.StreamingAead
    )
    plaintext = b'some plaintext'
    associated_data = b'associated_data'
    ciphertext = saead.new_encrypting_stream(
        io.BytesIO(plaintext), associated_data
    ).read()
    self.assertEqual(
        saead.new_decrypting_stream(
            io.BytesIO(ciphertext), associated_data
        ).read(),
        plaintext,
    )

  @parameterized.parameters(lang_and_invalid_keys())
  def test_create_streaming_aead_invalid_key_fails(
      self, lang: str, key: aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey
  ):
    """Checks that creating a StreamingAead from an invalid key raises TinkError."""
    keyset = to_keyset(key)
    with self.assertRaises(tink.TinkError):
      testing_servers.remote_primitive(
          lang, keyset.SerializeToString(), streaming_aead.StreamingAead
      )

  def test_output_prefix_ignored(self):
    """Checks that streaming AEAD ignores the key's output prefix type.

    A random (language, output prefix) pair is used to encrypt and another
    random pair to decrypt. Decryption must recover the plaintext regardless
    of the prefix types chosen.
    """
    langs = tink_config.supported_languages_for_key_type(
        'AesGcmHkdfStreamingKey'
    )
    output_prefix_types = [
        tink_pb2.RAW,
        tink_pb2.CRUNCHY,
        tink_pb2.LEGACY,
        tink_pb2.TINK,
    ]
    lang_1 = random.choice(langs)
    lang_2 = random.choice(langs)
    output_prefix_1 = random.choice(output_prefix_types)
    output_prefix_2 = random.choice(output_prefix_types)
    # subTest is used only to label the randomly chosen combination on failure.
    with self.subTest(
        f'Testing with languages ({lang_1}, {lang_2}) and output prefix types '
        f'({tink_pb2.OutputPrefixType.Name(output_prefix_1)}, '
        f'{tink_pb2.OutputPrefixType.Name(output_prefix_2)})'
    ):
      keyset = to_keyset(simple_valid_key())
      keyset.key[0].output_prefix_type = output_prefix_1
      saead_1 = testing_servers.remote_primitive(
          lang_1, keyset.SerializeToString(), streaming_aead.StreamingAead
      )
      keyset.key[0].output_prefix_type = output_prefix_2
      saead_2 = testing_servers.remote_primitive(
          lang_2, keyset.SerializeToString(), streaming_aead.StreamingAead
      )
      plaintext = b'some plaintext'
      associated_data = b'associated_data'
      ciphertext = saead_1.new_encrypting_stream(
          io.BytesIO(plaintext), associated_data
      ).read()
      self.assertEqual(
          saead_2.new_decrypting_stream(
              io.BytesIO(ciphertext), associated_data
          ).read(),
          plaintext,
      )

  # This test targets AesGcmHkdfStreamingKey. It previously (wrongly) took its
  # languages from the AesCtrHmacStreamingKey list.
  @parameterized.parameters(
      tink_config.supported_languages_for_key_type('AesGcmHkdfStreamingKey')
  )
  def test_manually_created_test_vector(self, lang: str):
    """This test uses a ciphertext created by looking at the documentation.

    See https://developers.google.com/tink/streaming-aead/aes_gcm_hkdf_streaming
    for the documentation. The goal is to ensure that the documentation is
    clear; we expect readers to read this with the documentation.

    Args:
      lang: the language to test
    """

    h2b = binascii.a2b_hex

    key = aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingKey(
        version=0,
        params=aes_gcm_hkdf_streaming_pb2.AesGcmHkdfStreamingParams(
            ciphertext_segment_size=64,
            derived_key_size=16,
            hkdf_hash_type=common_pb2.HashType.SHA1,
        ),
        key_value=h2b('6eb56cdc726dfbe5d57f2fcdc6e9345b')
    )
    # We set the message to be:
    msg = (
        b'This is a fairly long plaintext. '
        + b'It is of the exact length to create three output blocks. '
    )
    #
    # We set the associated data to be:
    associated_data = b'aad'

    # We picked the header at random: Note the length is 24 = 0x18.
    header_length = h2b('18')
    salt = h2b('93b3af5e14ab378d065addfc8484da64')
    nonce_prefix = h2b('2c0862877baea8')
    header = header_length + salt + nonce_prefix
    # hkdf.hkdf_sha1(ikm=key_value, salt=salt, info=aad, size=16) gives
    # '66dd511791296a6cfc94a24041fcab9f'
    # aes_key = h2b('66dd511791296a6cfc94a24041fcab9f')

    # We next split the message:
    # len(msg) = 90
    # len(M_0) = 24 = CiphertextSegmentSize(64) - Headerlength(24) - 16
    # len(M_1) = 48 = CiphertextSegmentSize(64) - 16
    # len(M_2) = 18 < CiphertextSegmentSize(64) - 16
    # msg_0 = msg[:24]
    # msg_1 = msg[24:72]
    # msg_2 = msg[72:]
    #
    # AES GCM computations with key = 66dd511791296a6cfc94a24041fcab9f
    #
    #
    # IV = nonce_prefix + segment_nr + b | plaintext | result
    # -----------------------------------------------------------------------
    # 2c0862877baea8 00000000 00         | msg_0     | c_0
    # 2c0862877baea8 00000001 00         | msg_1     | c_1
    # 2c0862877baea8 00000002 01         | msg_2     | c_2
    c0 = h2b(
        b'db92d9c77406a406168478821c4298eab3e6d531277f4c1a'
        + b'051714faebcaefcbca7b7be05e9445ea'
    )
    c1 = h2b(
        b'a0bb2904153398a25084dd80ae0edcd1c3079fcea2cd3770'
        + b'630ee36f7539207b8ec9d754956d486b71cdf989f0ed6fba'
        + b'6779b63558be0a66e668df14e1603cd2'
    )
    c2 = h2b(
        b'af8944844078345286d0b292e772e7190775'
        + b'c51a0f83e40c0b75821027e7e538e111'
    )

    ciphertext = header + c0 + c1 + c2

    keyset = to_keyset(key)
    saead = testing_servers.remote_primitive(
        lang, keyset.SerializeToString(), streaming_aead.StreamingAead
    )

    self.assertEqual(
        saead.new_decrypting_stream(
            io.BytesIO(ciphertext), associated_data
        ).read(),
        msg,
    )
347
348
if __name__ == '__main__':
  # Run this module's tests with absl's test runner when executed directly.
  absltest.main()
351