1# Copyright 2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import binascii
16import io
17import random
18
19from absl.testing import absltest
20from absl.testing import parameterized
21import tink
22from tink import streaming_aead
23
24from tink.proto import aes_ctr_hmac_streaming_pb2
25from tink.proto import common_pb2
26from tink.proto import hmac_pb2
27from tink.proto import tink_pb2
28import tink_config
29from util import testing_servers
30
31
def setUpModule():
  """unittest module-level fixture: prepares the environment for all tests.

  Calls `streaming_aead.register()` and then starts the testing servers under
  the name 'aes_ctr_hmac_streaming_key_test' (presumably the remote
  per-language servers used by `testing_servers.remote_primitive` — confirm in
  util/testing_servers).
  """
  streaming_aead.register()
  testing_servers.start('aes_ctr_hmac_streaming_key_test')
35
36
def tearDownModule():
  """unittest module-level fixture: stops the servers started in setUpModule."""
  testing_servers.stop()
39
40
def to_keyset(
    key: aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingKey,
) -> tink_pb2.Keyset:
  """Wraps an AesCtrHmacStreamingKey as the only key of a fresh Keyset.

  The single entry is ENABLED, uses the RAW output prefix, has key id 1234 and
  is also the keyset's primary key.

  Args:
    key: The AesCtrHmacStreamingKey to embed; it is stored serialized.

  Returns:
    A Keyset containing exactly one key that holds the serialized `key`.
  """
  key_id = 1234
  key_data = tink_pb2.KeyData(
      type_url='type.googleapis.com/google.crypto.tink.AesCtrHmacStreamingKey',
      value=key.SerializeToString(),
      key_material_type='SYMMETRIC',
  )
  entry = tink_pb2.Keyset.Key(
      key_data=key_data,
      output_prefix_type=tink_pb2.OutputPrefixType.RAW,
      status=tink_pb2.KeyStatusType.ENABLED,
      key_id=key_id,
  )
  return tink_pb2.Keyset(primary_key_id=key_id, key=[entry])
60
61
def simple_valid_key() -> aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingKey:
  """Builds the baseline AesCtrHmacStreamingKey that the test tables vary.

  Parameters: 512-byte ciphertext segments, a 16-byte derived key, HKDF with
  SHA256, and HMAC-SHA256 with 16-byte tags over a fixed 16-byte key value.
  A new object is returned on every call, so callers may mutate it freely.
  """
  hmac_params = hmac_pb2.HmacParams(
      hash=common_pb2.HashType.SHA256, tag_size=16
  )
  params = aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingParams(
      ciphertext_segment_size=512,
      derived_key_size=16,
      hkdf_hash_type=common_pb2.HashType.SHA256,
      hmac_params=hmac_params,
  )
  return aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingKey(
      version=0, params=params, key_value=b'0123456789abcdef'
  )
78
79
def lang_and_valid_keys_create_and_encrypt():
  """Returns (lang, key) pairs of valid keys that are also used to encrypt.

  Each key is a variation of `simple_valid_key()`. Each key is paired with
  every language that supports AesCtrHmacStreamingKey. All languages for one
  key are listed before the next key.
  """
  langs = tink_config.supported_languages_for_key_type('AesCtrHmacStreamingKey')
  sha1 = common_pb2.HashType.SHA1
  sha256 = common_pb2.HashType.SHA256
  sha512 = common_pb2.HashType.SHA512

  # The unmodified baseline key.
  valid_keys = [simple_valid_key()]

  # 32-byte derived key, with a matching 32-byte input key value.
  wide_key = simple_valid_key()
  assert wide_key.params.derived_key_size == 16
  wide_key.params.derived_key_size = 32
  wide_key.key_value = b'0123456789abcdef0123456789abcdef'
  valid_keys.append(wide_key)

  # HMAC tag sizes: 10, 11, and the full digest length of each hash.
  hmac_cases = (
      (sha1, 10),
      (sha1, 11),
      (sha1, 20),
      (sha256, 10),
      (sha256, 11),
      (sha256, 32),
      (sha512, 10),
      (sha512, 11),
      (sha512, 64),
  )
  for hmac_hash, tag_size in hmac_cases:
    variant = simple_valid_key()
    variant.params.hmac_params.hash = hmac_hash
    variant.params.hmac_params.tag_size = tag_size
    valid_keys.append(variant)

  # HKDF hash types.
  for hkdf_hash in (sha1, sha256, sha512):
    variant = simple_valid_key()
    variant.params.hkdf_hash_type = hkdf_hash
    valid_keys.append(variant)

  # Minimum ciphertext_segment_size: derived_key_size + tag_size + 9.
  smallest = simple_valid_key()
  smallest.params.ciphertext_segment_size = (
      smallest.params.derived_key_size
      + smallest.params.hmac_params.tag_size
      + 9
  )
  valid_keys.append(smallest)

  return [(lang, key) for key in valid_keys for lang in langs]
175
176
def lang_and_valid_keys_create_only():
  """Returns (lang, key) pairs of valid keys that are only used for creation.

  Contains every pair from `lang_and_valid_keys_create_and_encrypt()`. It then
  appends a key with the largest ciphertext_segment_size (2**31 - 1), paired
  with every supporting language. That key is kept out of the encrypt list
  because of the Java issue noted below.
  """
  pairs = lang_and_valid_keys_create_and_encrypt()
  langs = tink_config.supported_languages_for_key_type('AesCtrHmacStreamingKey')

  # TODO(b/268193523): Java crashes with ciphertext_segment_size = 2**31 - 1
  max_segment_key = simple_valid_key()
  max_segment_key.params.ciphertext_segment_size = 2**31 - 1
  pairs.extend((lang, max_segment_key) for lang in langs)

  return pairs
188
189
def lang_and_invalid_keys():
  """Returns (lang, key) pairs of keys that must be rejected at creation.

  Each key is `simple_valid_key()` with one or two fields changed. Each key is
  paired with every language that supports AesCtrHmacStreamingKey. All
  languages for one key are listed before the next key.
  """
  langs = tink_config.supported_languages_for_key_type('AesCtrHmacStreamingKey')
  hash_type = common_pb2.HashType
  invalid_keys = []

  # Derived key size that is neither 16 nor 32.
  variant = simple_valid_key()
  variant.params.derived_key_size = 24
  invalid_keys.append(variant)

  # HKDF hash types that are not accepted.
  for hkdf_hash in (hash_type.SHA224, hash_type.SHA384):
    variant = simple_valid_key()
    variant.params.hkdf_hash_type = hkdf_hash
    invalid_keys.append(variant)

  # Check requirement len(InitialKeyMaterial) >= DerivedKeySize
  variant = simple_valid_key()
  variant.key_value = b'0123456789abcdef'
  variant.params.derived_key_size = 32
  invalid_keys.append(variant)

  # Unknown HKDF hash type.
  variant = simple_valid_key()
  variant.params.hkdf_hash_type = hash_type.UNKNOWN_HASH
  invalid_keys.append(variant)

  # ciphertext_segment_size one below the minimum
  # (derived_key_size + tag_size + 8).
  variant = simple_valid_key()
  variant.params.ciphertext_segment_size = (
      variant.params.derived_key_size
      + variant.params.hmac_params.tag_size
      + 8
  )
  invalid_keys.append(variant)

  # HMAC tag sizes just outside [10, digest length] for each hash.
  tag_cases = (
      (hash_type.SHA1, 9),
      (hash_type.SHA1, 21),
      (hash_type.SHA256, 9),
      (hash_type.SHA256, 33),
      (hash_type.SHA512, 9),
      (hash_type.SHA512, 65),
  )
  for hmac_hash, tag_size in tag_cases:
    variant = simple_valid_key()
    variant.params.hmac_params.hash = hmac_hash
    variant.params.hmac_params.tag_size = tag_size
    invalid_keys.append(variant)

  # HMAC hash types that are not accepted (baseline tag size kept).
  for hmac_hash in (hash_type.SHA224, hash_type.SHA384):
    variant = simple_valid_key()
    variant.params.hmac_params.hash = hmac_hash
    invalid_keys.append(variant)

  # ciphertext_segment_size of 2**31, one above the signed 32-bit maximum.
  variant = simple_valid_key()
  variant.params.ciphertext_segment_size = 2**31
  invalid_keys.append(variant)

  return [(lang, key) for key in invalid_keys for lang in langs]
283
284
class AesCtrHmacStreamingKeyTest(parameterized.TestCase):
  """Tests specific for keys of type AesCtrHmacStreamingKey.

  See https://developers.google.com/tink/streaming-aead/aes_ctr_hmac_streaming
  for the documentation.
  """

  @parameterized.parameters(lang_and_valid_keys_create_only())
  def test_create_streaming_aead(
      self, lang: str, key: aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingKey
  ):
    """Checks that `lang` can create a StreamingAead from the valid `key`."""
    keyset = to_keyset(key)
    testing_servers.remote_primitive(
        lang, keyset.SerializeToString(), streaming_aead.StreamingAead
    )

  @parameterized.parameters(lang_and_valid_keys_create_and_encrypt())
  def test_create_streaming_aead_encrypt_decrypt(
      self, lang: str, key: aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingKey
  ):
    """Checks that `lang` decrypts its own ciphertext back to the plaintext."""
    keyset = to_keyset(key)
    saead = testing_servers.remote_primitive(
        lang, keyset.SerializeToString(), streaming_aead.StreamingAead
    )
    plaintext = b'some plaintext'
    ad = b'associated_data'
    ciphertext = saead.new_encrypting_stream(io.BytesIO(plaintext), ad).read()
    self.assertEqual(
        saead.new_decrypting_stream(io.BytesIO(ciphertext), ad).read(),
        plaintext,
    )

  @parameterized.parameters(lang_and_invalid_keys())
  def test_create_streaming_aead_invalid_key_fails(
      self, lang: str, key: aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingKey
  ):
    """Checks that creating a StreamingAead from an invalid `key` fails."""
    keyset = to_keyset(key)
    with self.assertRaises(tink.TinkError):
      testing_servers.remote_primitive(
          lang, keyset.SerializeToString(), streaming_aead.StreamingAead
      )

  def test_output_prefix_ignored(self):
    """Checks that ciphertexts do not depend on the key's output prefix type.

    A random pair of languages and a random pair of output prefix types is
    chosen. One combination encrypts and the other decrypts. The chosen
    combination appears in the subTest name so a failure can be reproduced.
    """
    langs = tink_config.supported_languages_for_key_type(
        'AesCtrHmacStreamingKey'
    )
    prefix_types = [
        tink_pb2.RAW,
        tink_pb2.CRUNCHY,
        tink_pb2.LEGACY,
        tink_pb2.TINK,
    ]
    # Draw in the same order as before: two languages, then two prefixes.
    lang_1 = random.choice(langs)
    lang_2 = random.choice(langs)
    output_prefix_1 = random.choice(prefix_types)
    output_prefix_2 = random.choice(prefix_types)
    with self.subTest(
        f'Testing with languages ({lang_1}, {lang_2}) and output prefix types '
        f'({tink_pb2.OutputPrefixType.Name(output_prefix_1)}, '
        f'{tink_pb2.OutputPrefixType.Name(output_prefix_2)})'
    ):
      keyset = to_keyset(simple_valid_key())
      keyset.key[0].output_prefix_type = output_prefix_1
      saead_1 = testing_servers.remote_primitive(
          lang_1, keyset.SerializeToString(), streaming_aead.StreamingAead
      )
      keyset.key[0].output_prefix_type = output_prefix_2
      saead_2 = testing_servers.remote_primitive(
          lang_2, keyset.SerializeToString(), streaming_aead.StreamingAead
      )
      plaintext = b'some plaintext'
      associated_data = b'associated_data'
      ciphertext = saead_1.new_encrypting_stream(
          io.BytesIO(plaintext), associated_data
      ).read()
      self.assertEqual(
          saead_2.new_decrypting_stream(
              io.BytesIO(ciphertext), associated_data
          ).read(),
          plaintext,
      )

  @parameterized.parameters(
      tink_config.supported_languages_for_key_type('AesCtrHmacStreamingKey')
  )
  def test_manually_created_test_vector(self, lang: str):
    """Tests using a ciphertext created by looking at the documentation.

    See https://developers.google.com/tink/streaming-aead/aes_ctr_hmac_streaming
    for the documentation. The goal is to ensure that the documentation is
    clear; we expect readers to read this with the documentation.

    Args:
      lang: The language to test.
    """

    def xor(b1: bytes, b2: bytes) -> bytes:
      # zip() would silently truncate to the shorter input. A typo in one of
      # the hex constants below would then show up as a confusing decryption
      # failure, so fail loudly here instead.
      if len(b1) != len(b2):
        raise ValueError(f'xor length mismatch: {len(b1)} != {len(b2)}')
      return bytes(i ^ j for (i, j) in zip(b1, b2))

    h2b = binascii.a2b_hex

    key = aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingKey(
        version=0,
        params=aes_ctr_hmac_streaming_pb2.AesCtrHmacStreamingParams(
            ciphertext_segment_size=64,
            derived_key_size=16,
            hkdf_hash_type=common_pb2.HashType.SHA1,
            hmac_params=hmac_pb2.HmacParams(
                hash=common_pb2.HashType.SHA256, tag_size=32
            ),
        ),
        key_value=h2b('6eb56cdc726dfbe5d57f2fcdc6e9345b'),
    )
    # We set the message to be:
    msg = b'This is a fairly long plaintext. However, it is not crazy long.'
    #
    # We set the associated data to be:
    aad = b'aad'

    # We picked the header at random. Its length is
    # 1 + DerivedKeySize(16) + 7 = 24 = 0x18.
    header_length = h2b('18')
    salt = h2b('93b3af5e14ab378d065addfc8484da64')
    nonce_prefix = h2b('2c0862877baea8')
    header = header_length + salt + nonce_prefix
    # hkdf.hkdf_sha1(ikm=key_value, salt=salt, info=aad, size=48) gives
    # '66dd511791296a6cfc94a24041fcab9f' +
    # '0f736d6e85c448c2c8cc30f094d7e2d89e1a4c6a2dea4e9c8d1d2015e54c609a'
    # aes_key = h2b('66dd511791296a6cfc94a24041fcab9f')
    # hmac_key = h2b(
    #        '0f736d6e85c448c2c8cc30f094d7e2d89e1a4c6a2dea4e9c8d1d2015e54c609a')

    # We next split the message:
    # len(msg) = 63.
    # len(M_0) = 8 = CiphertextSegmentSize(64) - Headerlength(24) - TagSize(32)
    # len(M_1) = 32 = CiphertextSegmentSize(64) - TagSize(32)
    # len(M_2) = 23 < CiphertextSegmentSize(64) - TagSize(32)
    msg_0 = msg[:8]
    msg_1 = msg[8:40]
    msg_2 = msg[40:]

    # Relevant AES computations with key = 66dd511791296a6cfc94a24041fcab9f.
    # b is 01 for the last segment and 00 otherwise; i counts AES blocks
    # within the segment.
    #
    # nonce_prefix + segment_nr + b + i   | Out
    # -----------------------------------------------------------------------
    # 2c0862877baea8 00000000 00 00000000 | ea8e18301bd57bfdd2f903025950c827
    # 2c0862877baea8 00000001 00 00000000 | 2999c8ea5401704243c8cd77929fd526
    # 2c0862877baea8 00000001 00 00000001 | 17fec5542a842446251bb2f3a81f6249
    # 2c0862877baea8 00000002 01 00000000 | 70fe58e44835a6602952749e763637d9
    # 2c0862877baea8 00000002 01 00000001 | d973bca83580867766f38b056d735902
    #
    c0 = xor(msg_0, h2b('ea8e18301bd57bfd'))
    c1 = xor(msg_1[:16], h2b('2999c8ea5401704243c8cd77929fd526')) + xor(
        msg_1[16:32], h2b('17fec5542a842446251bb2f3a81f6249')
    )
    c2 = xor(msg_2[:16], h2b('70fe58e44835a6602952749e763637d9')) + xor(
        msg_2[16:], h2b('d973bca8358086')
    )

    # t0 = HMAC-SHA256(key=hmac_key,
    #                  msg=h2b('2c0862877baea8000000000000000000') + c0)
    t0 = h2b('8303ca71c04d8e06e1b01cff7c1178af47dac031517b1f6a2d9be84105677a68')
    # t1 = HMAC-SHA256(key=hmac_key,
    #                  msg=h2b('2c0862877baea8000000010000000000') + c1)
    t1 = h2b('834d890839f37f762caddc029cc673300ff107fd51f9a62058fcd00befc362e5')
    # t2 = HMAC-SHA256(key=hmac_key,
    #                  msg=h2b('2c0862877baea8000000020100000000') + c2)
    t2 = h2b('5fb0c893903271af38380c2f355cb85e5ec571648513123321bde0c6042f43c7')

    ciphertext = header + c0 + t0 + c1 + t1 + c2 + t2

    keyset = to_keyset(key)
    saead = testing_servers.remote_primitive(
        lang, keyset.SerializeToString(), streaming_aead.StreamingAead
    )

    self.assertEqual(
        saead.new_decrypting_stream(io.BytesIO(ciphertext), aad).read(),
        msg,
    )
464
if __name__ == '__main__':
  # Run the test cases through absltest when this file is executed directly.
  absltest.main()
467