xref: /aosp_15_r20/external/tink/go/prf/hkdf_prf_key_manager_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
16
17package prf_test
18
19import (
20	"bytes"
21	"encoding/hex"
22	"fmt"
23	"testing"
24
25	"github.com/google/go-cmp/cmp"
26	"google.golang.org/protobuf/proto"
27	"github.com/google/tink/go/core/registry"
28	"github.com/google/tink/go/internal/internalregistry"
29	"github.com/google/tink/go/prf"
30	"github.com/google/tink/go/prf/subtle"
31	"github.com/google/tink/go/subtle/random"
32	"github.com/google/tink/go/testutil"
33	commonpb "github.com/google/tink/go/proto/common_go_proto"
34	hkdfpb "github.com/google/tink/go/proto/hkdf_prf_go_proto"
35	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
36)
37
38func TestGetPrimitiveHKDFBasic(t *testing.T) {
39	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
40	if err != nil {
41		t.Errorf("HKDF PRF key manager not found: %s", err)
42	}
43	testKeys := genValidHKDFKeys()
44	for i := 0; i < len(testKeys); i++ {
45		serializedKey, _ := proto.Marshal(testKeys[i])
46		p, err := km.Primitive(serializedKey)
47		if err != nil {
48			t.Errorf("unexpected error in test case %d: %s", i, err)
49		}
50		if err := validateHKDFPrimitive(p, testKeys[i]); err != nil {
51			t.Errorf("%s", err)
52		}
53	}
54}
55
56func TestGetPrimitiveHKDFWithInvalidInput(t *testing.T) {
57	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
58	if err != nil {
59		t.Errorf("cannot obtain HKDF PRF key manager: %s", err)
60	}
61	// invalid key
62	testKeys := genInvalidHKDFKeys()
63	for i := 0; i < len(testKeys); i++ {
64		serializedKey, _ := proto.Marshal(testKeys[i])
65		if _, err := km.Primitive(serializedKey); err == nil {
66			t.Errorf("expect an error in test case %d", i)
67		}
68	}
69	if _, err := km.Primitive(nil); err == nil {
70		t.Errorf("expect an error when input is nil")
71	}
72	// empty input
73	if _, err := km.Primitive([]byte{}); err == nil {
74		t.Errorf("expect an error when input is empty")
75	}
76}
77
78func TestNewKeyHKDFMultipleTimes(t *testing.T) {
79	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
80	if err != nil {
81		t.Errorf("cannot obtain HKDF PRF key manager: %s", err)
82	}
83	serializedFormat, _ := proto.Marshal(testutil.NewHKDFPRFKeyFormat(commonpb.HashType_SHA256, make([]byte, 0)))
84	keys := make(map[string]bool)
85	nTest := 26
86	for i := 0; i < nTest; i++ {
87		key, _ := km.NewKey(serializedFormat)
88		serializedKey, _ := proto.Marshal(key)
89		keys[string(serializedKey)] = true
90
91		keyData, _ := km.NewKeyData(serializedFormat)
92		serializedKey = keyData.Value
93		keys[string(serializedKey)] = true
94	}
95	if len(keys) != nTest*2 {
96		t.Errorf("key is repeated")
97	}
98}
99
100func TestNewKeyHKDFBasic(t *testing.T) {
101	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
102	if err != nil {
103		t.Errorf("cannot obtain HKDF PRF key manager: %s", err)
104	}
105	testFormats := genValidHKDFKeyFormats()
106	for i := 0; i < len(testFormats); i++ {
107		serializedFormat, _ := proto.Marshal(testFormats[i])
108		key, err := km.NewKey(serializedFormat)
109		if err != nil {
110			t.Errorf("unexpected error in test case %d: %s", i, err)
111		}
112		if err := validateHKDFKey(testFormats[i], key.(*hkdfpb.HkdfPrfKey)); err != nil {
113			t.Errorf("%s", err)
114		}
115	}
116}
117
118func TestNewKeyHKDFWithInvalidInput(t *testing.T) {
119	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
120	if err != nil {
121		t.Errorf("cannot obtain HKDF PRF key manager: %s", err)
122	}
123	// invalid key formats
124	testFormats := genInvalidHKDFKeyFormats()
125	for i := 0; i < len(testFormats); i++ {
126		serializedFormat, err := proto.Marshal(testFormats[i])
127		if err != nil {
128			fmt.Println("Error!")
129		}
130		if _, err := km.NewKey(serializedFormat); err == nil {
131			t.Errorf("expect an error in test case %d: %s", i, err)
132		}
133	}
134	if _, err := km.NewKey(nil); err == nil {
135		t.Errorf("expect an error when input is nil")
136	}
137	// empty input
138	if _, err := km.NewKey([]byte{}); err == nil {
139		t.Errorf("expect an error when input is empty")
140	}
141}
142
143func TestNewKeyDataHKDFBasic(t *testing.T) {
144	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
145	if err != nil {
146		t.Errorf("cannot obtain HKDF PRF key manager: %s", err)
147	}
148	testFormats := genValidHKDFKeyFormats()
149	for i := 0; i < len(testFormats); i++ {
150		serializedFormat, _ := proto.Marshal(testFormats[i])
151		keyData, err := km.NewKeyData(serializedFormat)
152		if err != nil {
153			t.Errorf("unexpected error in test case %d: %s", i, err)
154		}
155		if keyData.TypeUrl != testutil.HKDFPRFTypeURL {
156			t.Errorf("incorrect type url in test case %d", i)
157		}
158		if keyData.KeyMaterialType != tinkpb.KeyData_SYMMETRIC {
159			t.Errorf("incorrect key material type in test case %d", i)
160		}
161		key := new(hkdfpb.HkdfPrfKey)
162		if err := proto.Unmarshal(keyData.Value, key); err != nil {
163			t.Errorf("invalid key value")
164		}
165		if err := validateHKDFKey(testFormats[i], key); err != nil {
166			t.Errorf("invalid key")
167		}
168	}
169}
170
171func TestNewKeyDataHKDFWithInvalidInput(t *testing.T) {
172	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
173	if err != nil {
174		t.Errorf("HKDF PRF key manager not found: %s", err)
175	}
176	// invalid key formats
177	testFormats := genInvalidHKDFKeyFormats()
178	for i := 0; i < len(testFormats); i++ {
179		serializedFormat, _ := proto.Marshal(testFormats[i])
180		if _, err := km.NewKeyData(serializedFormat); err == nil {
181			t.Errorf("expect an error in test case %d", i)
182		}
183	}
184	// nil input
185	if _, err := km.NewKeyData(nil); err == nil {
186		t.Errorf("expect an error when input is nil")
187	}
188}
189
190func TestHKDFDoesSupport(t *testing.T) {
191	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
192	if err != nil {
193		t.Errorf("HKDF PRF key manager not found: %s", err)
194	}
195	if !km.DoesSupport(testutil.HKDFPRFTypeURL) {
196		t.Errorf("HKDFPRFKeyManager must support %s", testutil.HKDFPRFTypeURL)
197	}
198	if km.DoesSupport("some bad type") {
199		t.Errorf("HKDFPRFKeyManager must support only %s", testutil.HKDFPRFTypeURL)
200	}
201}
202
203func TestHKDFTypeURL(t *testing.T) {
204	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
205	if err != nil {
206		t.Errorf("HKDF PRF key manager not found: %s", err)
207	}
208	if km.TypeURL() != testutil.HKDFPRFTypeURL {
209		t.Errorf("incorrect GetKeyType()")
210	}
211}
212
213func TestHKDFKeyMaterialType(t *testing.T) {
214	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
215	if err != nil {
216		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.HKDFPRFTypeURL, err)
217	}
218	keyManager, ok := km.(internalregistry.DerivableKeyManager)
219	if !ok {
220		t.Fatalf("key manager is not DerivableKeyManager")
221	}
222	if got, want := keyManager.KeyMaterialType(), tinkpb.KeyData_SYMMETRIC; got != want {
223		t.Errorf("KeyMaterialType() = %v, want %v", got, want)
224	}
225}
226
227func TestHKDFDeriveKey(t *testing.T) {
228	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
229	if err != nil {
230		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.HKDFPRFTypeURL, err)
231	}
232	keyManager, ok := km.(internalregistry.DerivableKeyManager)
233	if !ok {
234		t.Fatalf("key manager is not DerivableKeyManager")
235	}
236
237	var keySize uint32 = 32
238	for _, test := range []struct {
239		name     string
240		hashType commonpb.HashType
241		salt     []byte
242	}{
243		{
244			name:     "SHA256",
245			hashType: commonpb.HashType_SHA256,
246			salt:     make([]byte, 0),
247		},
248		{
249			name:     "SHA256/salt",
250			hashType: commonpb.HashType_SHA256,
251			salt:     []byte{0x01, 0x03, 0x42},
252		},
253		{
254			name:     "SHA512",
255			hashType: commonpb.HashType_SHA512,
256			salt:     make([]byte, 0),
257		},
258		{
259			name:     "SHA512/salt",
260			hashType: commonpb.HashType_SHA512,
261			salt:     []byte{0x01, 0x03, 0x42},
262		},
263	} {
264		t.Run(test.name, func(t *testing.T) {
265			keyFormat := testutil.NewHKDFPRFKeyFormat(test.hashType, test.salt)
266			serializedKeyFormat, err := proto.Marshal(keyFormat)
267			if err != nil {
268				t.Fatalf("proto.Marshal(%v) err = %v, want nil", keyFormat, err)
269			}
270
271			rand := random.GetRandomBytes(keySize)
272			buf := &bytes.Buffer{}
273			buf.Write(rand) // never returns a non-nil error
274
275			k, err := keyManager.DeriveKey(serializedKeyFormat, buf)
276			if err != nil {
277				t.Fatalf("keyManager.DeriveKey() err = %v, want nil", err)
278			}
279			key := k.(*hkdfpb.HkdfPrfKey)
280			if got, want := len(key.GetKeyValue()), int(keySize); got != want {
281				t.Errorf("key length = %d, want %d", got, want)
282			}
283			if diff := cmp.Diff(key.GetKeyValue(), rand); diff != "" {
284				t.Errorf("incorrect derived key: diff = %v", diff)
285			}
286		})
287	}
288}
289
290func TestHKDFDeriveKeyFailsWithInvalidKeyFormats(t *testing.T) {
291	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
292	if err != nil {
293		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.HKDFPRFTypeURL, err)
294	}
295	keyManager, ok := km.(internalregistry.DerivableKeyManager)
296	if !ok {
297		t.Fatalf("key manager is not DerivableKeyManager")
298	}
299
300	var keySize uint32 = 32
301	validKeyFormat := &hkdfpb.HkdfPrfKeyFormat{
302		Params:  testutil.NewHKDFPRFParams(commonpb.HashType_SHA256, make([]byte, 0)),
303		KeySize: keySize,
304		Version: 0,
305	}
306	serializedValidKeyFormat, err := proto.Marshal(validKeyFormat)
307	if err != nil {
308		t.Fatalf("proto.Marshal(%v) err = %v, want nil", validKeyFormat, err)
309	}
310	buf := bytes.NewBuffer(random.GetRandomBytes(keySize))
311	if _, err := keyManager.DeriveKey(serializedValidKeyFormat, buf); err != nil {
312		t.Fatalf("keyManager.DeriveKey() err = %v, want nil", err)
313	}
314
315	for _, test := range []struct {
316		name      string
317		keyFormat *hkdfpb.HkdfPrfKeyFormat
318		randLen   uint32
319	}{
320		{
321			name: "invalid key size",
322			keyFormat: &hkdfpb.HkdfPrfKeyFormat{
323				Params:  validKeyFormat.GetParams(),
324				KeySize: 16,
325				Version: validKeyFormat.GetVersion(),
326			},
327			randLen: keySize,
328		},
329		{
330			name:      "not enough randomness",
331			keyFormat: validKeyFormat,
332			randLen:   16,
333		},
334		{
335			name: "invalid version",
336			keyFormat: &hkdfpb.HkdfPrfKeyFormat{
337				Params:  validKeyFormat.GetParams(),
338				KeySize: validKeyFormat.GetKeySize(),
339				Version: 100000,
340			},
341			randLen: keySize,
342		},
343		{
344			name:      "empty key format",
345			keyFormat: &hkdfpb.HkdfPrfKeyFormat{},
346			randLen:   keySize,
347		},
348		{
349			name:    "nil key format",
350			randLen: keySize,
351		},
352	} {
353		t.Run(test.name, func(t *testing.T) {
354			serializedKeyFormat, err := proto.Marshal(test.keyFormat)
355			if err != nil {
356				t.Fatalf("proto.Marshal(%v) err = %v, want nil", test.keyFormat, err)
357			}
358			buf := bytes.NewBuffer(random.GetRandomBytes(test.randLen))
359			if _, err := keyManager.DeriveKey(serializedKeyFormat, buf); err == nil {
360				t.Errorf("keyManager.DeriveKey() err = nil, want non-nil")
361			}
362		})
363	}
364}
365
366func TestHKDFDeriveKeyFailsWithMalformedSerializedKeyFormat(t *testing.T) {
367	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
368	if err != nil {
369		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.HKDFPRFTypeURL, err)
370	}
371	keyManager, ok := km.(internalregistry.DerivableKeyManager)
372	if !ok {
373		t.Fatalf("key manager is not DerivableKeyManager")
374	}
375
376	var keySize uint32 = 32
377	malformedSerializedKeyFormat := random.GetRandomBytes(
378		uint32(
379			proto.Size(&hkdfpb.HkdfPrfKeyFormat{
380				Params:  testutil.NewHKDFPRFParams(commonpb.HashType_SHA256, make([]byte, 0)),
381				KeySize: keySize,
382				Version: 0,
383			})))
384
385	buf := bytes.NewBuffer(random.GetRandomBytes(keySize))
386	if _, err := keyManager.DeriveKey(malformedSerializedKeyFormat, buf); err == nil {
387		t.Errorf("keyManager.DeriveKey() err = nil, want non-nil")
388	}
389}
390
391func TestAESGCMDeriveKeyFailsWithInsufficientRandomness(t *testing.T) {
392	km, err := registry.GetKeyManager(testutil.HKDFPRFTypeURL)
393	if err != nil {
394		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.HKDFPRFTypeURL, err)
395	}
396	keyManager, ok := km.(internalregistry.DerivableKeyManager)
397	if !ok {
398		t.Fatalf("key manager is not DerivableKeyManager")
399	}
400	keyFormat, err := proto.Marshal(testutil.NewHKDFPRFKeyFormat(commonpb.HashType_SHA256, []byte("salty")))
401	if err != nil {
402		t.Fatalf("proto.Marshal() err = %v, want nil", err)
403	}
404	var keySize uint32 = 32
405	{
406		buf := bytes.NewBuffer(random.GetRandomBytes(keySize))
407		if _, err := keyManager.DeriveKey(keyFormat, buf); err != nil {
408			t.Errorf("keyManager.DeriveKey() err = %v, want nil", err)
409		}
410	}
411	{
412		insufficientBuf := bytes.NewBuffer(random.GetRandomBytes(keySize - 1))
413		if _, err := keyManager.DeriveKey(keyFormat, insufficientBuf); err == nil {
414			t.Errorf("keyManager.DeriveKey() err = nil, want non-nil")
415		}
416	}
417}
418
419func genInvalidHKDFKeys() []proto.Message {
420	badVersionKey := testutil.NewHKDFPRFKey(commonpb.HashType_SHA256, make([]byte, 0))
421	badVersionKey.Version++
422	shortKey := testutil.NewHKDFPRFKey(commonpb.HashType_SHA256, make([]byte, 0))
423	shortKey.KeyValue = []byte{1, 1}
424	return []proto.Message{
425		// not a HKDFPRFKey
426		testutil.NewHKDFPRFParams(commonpb.HashType_SHA256, make([]byte, 0)),
427		// bad version
428		badVersionKey,
429		// key too short
430		shortKey,
431		// SHA-1
432		testutil.NewHKDFPRFKey(commonpb.HashType_SHA1, make([]byte, 0)),
433		// unknown hash type
434		testutil.NewHKDFPRFKey(commonpb.HashType_UNKNOWN_HASH, make([]byte, 0)),
435	}
436}
437
438func genInvalidHKDFKeyFormats() []proto.Message {
439	shortKeyFormat := testutil.NewHKDFPRFKeyFormat(commonpb.HashType_SHA256, make([]byte, 0))
440	shortKeyFormat.KeySize = 1
441	return []proto.Message{
442		// not a HKDFPRFKeyFormat
443		testutil.NewHMACParams(commonpb.HashType_SHA256, 32),
444		// key too short
445		shortKeyFormat,
446		// SHA-1
447		testutil.NewHKDFPRFKeyFormat(commonpb.HashType_SHA1, make([]byte, 0)),
448		// unknown hash type
449		testutil.NewHKDFPRFKeyFormat(commonpb.HashType_UNKNOWN_HASH, make([]byte, 0)),
450	}
451}
452
453func genValidHKDFKeyFormats() []*hkdfpb.HkdfPrfKeyFormat {
454	return []*hkdfpb.HkdfPrfKeyFormat{
455		testutil.NewHKDFPRFKeyFormat(commonpb.HashType_SHA256, make([]byte, 0)),
456		testutil.NewHKDFPRFKeyFormat(commonpb.HashType_SHA512, make([]byte, 0)),
457		testutil.NewHKDFPRFKeyFormat(commonpb.HashType_SHA256, []byte{0x01, 0x03, 0x42}),
458		testutil.NewHKDFPRFKeyFormat(commonpb.HashType_SHA512, []byte{0x01, 0x03, 0x42}),
459	}
460}
461
462func genValidHKDFKeys() []*hkdfpb.HkdfPrfKey {
463	return []*hkdfpb.HkdfPrfKey{
464		testutil.NewHKDFPRFKey(commonpb.HashType_SHA256, make([]byte, 0)),
465		testutil.NewHKDFPRFKey(commonpb.HashType_SHA512, make([]byte, 0)),
466		testutil.NewHKDFPRFKey(commonpb.HashType_SHA256, []byte{0x01, 0x03, 0x42}),
467		testutil.NewHKDFPRFKey(commonpb.HashType_SHA512, []byte{0x01, 0x03, 0x42}),
468	}
469}
470
471// Checks whether the given HKDFPRFKey matches the given key HKDFPRFKeyFormat
472func validateHKDFKey(format *hkdfpb.HkdfPrfKeyFormat, key *hkdfpb.HkdfPrfKey) error {
473	if format.KeySize != uint32(len(key.KeyValue)) ||
474		key.Params.Hash != format.Params.Hash {
475		return fmt.Errorf("key format and generated key do not match")
476	}
477	p, err := subtle.NewHKDFPRF(commonpb.HashType_name[int32(key.Params.Hash)], key.KeyValue, key.Params.Salt)
478	if err != nil {
479		return fmt.Errorf("cannot create primitive from key: %s", err)
480	}
481	return validateHKDFPrimitive(p, key)
482}
483
484// validateHKDFPrimitive checks whether the given primitive matches the given HKDFPRFKey
485func validateHKDFPrimitive(p interface{}, key *hkdfpb.HkdfPrfKey) error {
486	hkdfPrimitive := p.(prf.PRF)
487	prfPrimitive, err := subtle.NewHKDFPRF(commonpb.HashType_name[int32(key.Params.Hash)], key.KeyValue, key.Params.Salt)
488	if err != nil {
489		return fmt.Errorf("Could not create HKDF PRF with key material %q: %s", hex.EncodeToString(key.KeyValue), err)
490	}
491	data := random.GetRandomBytes(20)
492	res, err := hkdfPrimitive.ComputePRF(data, 16)
493	if err != nil {
494		return fmt.Errorf("prf computation failed: %s", err)
495	}
496	if len(res) != 16 {
497		return fmt.Errorf("prf computation did not produce 16 byte output")
498	}
499	res2, err := prfPrimitive.ComputePRF(data, 16)
500	if err != nil {
501		return fmt.Errorf("prf computation failed: %s", err)
502	}
503	if len(res2) != 16 {
504		return fmt.Errorf("prf computation did not produce 16 byte output")
505	}
506	if hex.EncodeToString(res) != hex.EncodeToString(res2) {
507		return fmt.Errorf("prf computation did not produce the same output for the same key and input")
508	}
509	return nil
510}
511