xref: /aosp_15_r20/external/tink/go/prf/hmac_prf_key_manager_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
16
17package prf_test
18
19import (
20	"bytes"
21	"encoding/hex"
22	"fmt"
23	"testing"
24
25	"github.com/google/go-cmp/cmp"
26	"google.golang.org/protobuf/proto"
27	"github.com/google/tink/go/core/registry"
28	"github.com/google/tink/go/internal/internalregistry"
29	"github.com/google/tink/go/prf"
30	"github.com/google/tink/go/prf/subtle"
31	"github.com/google/tink/go/subtle/random"
32	"github.com/google/tink/go/testutil"
33	commonpb "github.com/google/tink/go/proto/common_go_proto"
34	hmacpb "github.com/google/tink/go/proto/hmac_prf_go_proto"
35	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
36)
37
38func TestGetPrimitiveHMACBasic(t *testing.T) {
39	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
40	if err != nil {
41		t.Errorf("HMAC PRF key manager not found: %s", err)
42	}
43	testKeys := genValidHMACPRFKeys()
44	for i := 0; i < len(testKeys); i++ {
45		serializedKey, _ := proto.Marshal(testKeys[i])
46		p, err := km.Primitive(serializedKey)
47		if err != nil {
48			t.Errorf("unexpected error in test case %d: %s", i, err)
49		}
50		if err := validateHMACPRFPrimitive(p, testKeys[i]); err != nil {
51			t.Errorf("%s", err)
52		}
53	}
54}
55
56func TestGetPrimitiveHMACWithInvalidInput(t *testing.T) {
57	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
58	if err != nil {
59		t.Errorf("cannot obtain HMAC PRFkey manager: %s", err)
60	}
61	// invalid key
62	testKeys := genInvalidHMACPRFKeys()
63	for i := 0; i < len(testKeys); i++ {
64		serializedKey, _ := proto.Marshal(testKeys[i])
65		if _, err := km.Primitive(serializedKey); err == nil {
66			t.Errorf("expect an error in test case %d", i)
67		}
68	}
69	if _, err := km.Primitive(nil); err == nil {
70		t.Errorf("expect an error when input is nil")
71	}
72	// empty input
73	if _, err := km.Primitive([]byte{}); err == nil {
74		t.Errorf("expect an error when input is empty")
75	}
76}
77
78func TestNewKeyHMACMultipleTimes(t *testing.T) {
79	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
80	if err != nil {
81		t.Errorf("cannot obtain HMAC PRF key manager: %s", err)
82	}
83	serializedFormat, _ := proto.Marshal(testutil.NewHMACPRFKeyFormat(commonpb.HashType_SHA256))
84	keys := make(map[string]bool)
85	nTest := 26
86	for i := 0; i < nTest; i++ {
87		key, _ := km.NewKey(serializedFormat)
88		serializedKey, _ := proto.Marshal(key)
89		keys[string(serializedKey)] = true
90
91		keyData, _ := km.NewKeyData(serializedFormat)
92		serializedKey = keyData.Value
93		keys[string(serializedKey)] = true
94	}
95	if len(keys) != nTest*2 {
96		t.Errorf("key is repeated")
97	}
98}
99
100func TestNewKeyHMACBasic(t *testing.T) {
101	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
102	if err != nil {
103		t.Errorf("cannot obtain HMAC PRF key manager: %s", err)
104	}
105	testFormats := genValidHMACPRFKeyFormats()
106	for i := 0; i < len(testFormats); i++ {
107		serializedFormat, _ := proto.Marshal(testFormats[i])
108		key, err := km.NewKey(serializedFormat)
109		if err != nil {
110			t.Errorf("unexpected error in test case %d: %s", i, err)
111		}
112		if err := validateHMACPRFKey(testFormats[i], key.(*hmacpb.HmacPrfKey)); err != nil {
113			t.Errorf("%s", err)
114		}
115	}
116}
117
118func TestNewKeyHMACWithInvalidInput(t *testing.T) {
119	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
120	if err != nil {
121		t.Errorf("cannot obtain HMAC PRF key manager: %s", err)
122	}
123	// invalid key formats
124	testFormats := genInvalidHMACPRFKeyFormats()
125	for i := 0; i < len(testFormats); i++ {
126		serializedFormat, err := proto.Marshal(testFormats[i])
127		if err != nil {
128			fmt.Println("Error!")
129		}
130		if _, err := km.NewKey(serializedFormat); err == nil {
131			t.Errorf("expect an error in test case %d: %s", i, err)
132		}
133	}
134	if _, err := km.NewKey(nil); err == nil {
135		t.Errorf("expect an error when input is nil")
136	}
137	// empty input
138	if _, err := km.NewKey([]byte{}); err == nil {
139		t.Errorf("expect an error when input is empty")
140	}
141}
142
143func TestNewKeyDataHMACBasic(t *testing.T) {
144	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
145	if err != nil {
146		t.Errorf("cannot obtain HMAC PRF key manager: %s", err)
147	}
148	testFormats := genValidHMACPRFKeyFormats()
149	for i := 0; i < len(testFormats); i++ {
150		serializedFormat, _ := proto.Marshal(testFormats[i])
151		keyData, err := km.NewKeyData(serializedFormat)
152		if err != nil {
153			t.Errorf("unexpected error in test case %d: %s", i, err)
154		}
155		if keyData.TypeUrl != testutil.HMACPRFTypeURL {
156			t.Errorf("incorrect type url in test case %d", i)
157		}
158		if keyData.KeyMaterialType != tinkpb.KeyData_SYMMETRIC {
159			t.Errorf("incorrect key material type in test case %d", i)
160		}
161		key := new(hmacpb.HmacPrfKey)
162		if err := proto.Unmarshal(keyData.Value, key); err != nil {
163			t.Errorf("invalid key value")
164		}
165		if err := validateHMACPRFKey(testFormats[i], key); err != nil {
166			t.Errorf("invalid key")
167		}
168	}
169}
170
171func TestNewKeyDataHMACWithInvalidInput(t *testing.T) {
172	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
173	if err != nil {
174		t.Errorf("HMAC PRF key manager not found: %s", err)
175	}
176	// invalid key formats
177	testFormats := genInvalidHMACPRFKeyFormats()
178	for i := 0; i < len(testFormats); i++ {
179		serializedFormat, _ := proto.Marshal(testFormats[i])
180		if _, err := km.NewKeyData(serializedFormat); err == nil {
181			t.Errorf("expect an error in test case %d", i)
182		}
183	}
184	// nil input
185	if _, err := km.NewKeyData(nil); err == nil {
186		t.Errorf("expect an error when input is nil")
187	}
188}
189
190func TestHMACDoesSupport(t *testing.T) {
191	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
192	if err != nil {
193		t.Errorf("HMAC PRF key manager not found: %s", err)
194	}
195	if !km.DoesSupport(testutil.HMACPRFTypeURL) {
196		t.Errorf("HMACPRFKeyManager must support %s", testutil.HMACPRFTypeURL)
197	}
198	if km.DoesSupport("some bad type") {
199		t.Errorf("HMACPRFKeyManager must support only %s", testutil.HMACPRFTypeURL)
200	}
201}
202
203func TestHMACTypeURL(t *testing.T) {
204	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
205	if err != nil {
206		t.Errorf("HMAC PRF key manager not found: %s", err)
207	}
208	if km.TypeURL() != testutil.HMACPRFTypeURL {
209		t.Errorf("incorrect GetKeyType()")
210	}
211}
212
213func TestHMACKeyMaterialType(t *testing.T) {
214	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
215	if err != nil {
216		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.HMACPRFTypeURL, err)
217	}
218	keyManager, ok := km.(internalregistry.DerivableKeyManager)
219	if !ok {
220		t.Fatalf("key manager is not DerivableKeyManager")
221	}
222	if got, want := keyManager.KeyMaterialType(), tinkpb.KeyData_SYMMETRIC; got != want {
223		t.Errorf("KeyMaterialType() = %v, want %v", got, want)
224	}
225}
226
227func TestHMACDeriveKey(t *testing.T) {
228	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
229	if err != nil {
230		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.HMACPRFTypeURL, err)
231	}
232	keyManager, ok := km.(internalregistry.DerivableKeyManager)
233	if !ok {
234		t.Fatalf("key manager is not DerivableKeyManager")
235	}
236	keyFormat, err := proto.Marshal(&hmacpb.HmacPrfKeyFormat{
237		Version: testutil.HMACPRFKeyVersion,
238		KeySize: 16,
239		Params:  &hmacpb.HmacPrfParams{Hash: commonpb.HashType_SHA256},
240	})
241	if err != nil {
242		t.Fatalf("proto.Marshal() err = %v, want nil", err)
243	}
244	rand := random.GetRandomBytes(16)
245	buf := &bytes.Buffer{}
246	buf.Write(rand) // Never returns a non-nil error.
247	k, err := keyManager.DeriveKey(keyFormat, buf)
248	if err != nil {
249		t.Fatalf("keyManager.DeriveKey() err = %v, want nil", err)
250	}
251	key := k.(*hmacpb.HmacPrfKey)
252	if got, want := len(key.GetKeyValue()), 16; got != want {
253		t.Errorf("key length = %d, want %d", got, want)
254	}
255	if diff := cmp.Diff(key.GetKeyValue(), rand); diff != "" {
256		t.Errorf("incorrect derived key: diff = %v", diff)
257	}
258}
259
260func TestHMACDeriveKeyFailsWithInvalidKeyFormats(t *testing.T) {
261	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
262	if err != nil {
263		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.HMACPRFTypeURL, err)
264	}
265	keyManager, ok := km.(internalregistry.DerivableKeyManager)
266	if !ok {
267		t.Fatalf("key manager is not DerivableKeyManager")
268	}
269
270	validKeyFormat := &hmacpb.HmacPrfKeyFormat{
271		Version: testutil.HMACPRFKeyVersion,
272		KeySize: 16,
273		Params:  &hmacpb.HmacPrfParams{Hash: commonpb.HashType_SHA256},
274	}
275	serializedValidKeyFormat, err := proto.Marshal(validKeyFormat)
276	if err != nil {
277		t.Fatalf("proto.Marshal(%v) err = %v, want nil", validKeyFormat, err)
278	}
279	buf := bytes.NewBuffer(random.GetRandomBytes(validKeyFormat.KeySize))
280	if _, err := keyManager.DeriveKey(serializedValidKeyFormat, buf); err != nil {
281		t.Fatalf("keyManager.DeriveKey() err = %v, want nil", err)
282	}
283
284	for _, test := range []struct {
285		name    string
286		version uint32
287		keySize uint32
288		hash    commonpb.HashType
289	}{
290		{
291			name:    "invalid version",
292			version: 10,
293			keySize: validKeyFormat.KeySize,
294			hash:    validKeyFormat.Params.Hash,
295		},
296		{
297			name:    "invalid key size",
298			version: validKeyFormat.Version,
299			keySize: 10,
300			hash:    validKeyFormat.Params.Hash,
301		},
302		{
303			name:    "invalid hash",
304			version: validKeyFormat.Version,
305			keySize: validKeyFormat.KeySize,
306			hash:    commonpb.HashType_UNKNOWN_HASH,
307		},
308	} {
309		t.Run(test.name, func(t *testing.T) {
310			keyFormat, err := proto.Marshal(&hmacpb.HmacPrfKeyFormat{
311				Version: test.version,
312				KeySize: test.keySize,
313				Params:  &hmacpb.HmacPrfParams{Hash: test.hash},
314			})
315			if err != nil {
316				t.Fatalf("proto.Marshal() err = %v, want nil", err)
317			}
318			buf := bytes.NewBuffer(random.GetRandomBytes(test.keySize))
319			if _, err := keyManager.DeriveKey(keyFormat, buf); err == nil {
320				t.Errorf("keyManager.DeriveKey() err = nil, want non-nil")
321			}
322		})
323	}
324}
325
326func TestHMACDeriveKeyFailsWithMalformedKeyFormats(t *testing.T) {
327	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
328	if err != nil {
329		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.HMACPRFTypeURL, err)
330	}
331	keyManager, ok := km.(internalregistry.DerivableKeyManager)
332	if !ok {
333		t.Fatalf("key manager is not DerivableKeyManager")
334	}
335	// Proto messages start with a VarInt, which always ends with a byte with the
336	// MSB unset, so 0x80 is invalid.
337	invalidSerialization, err := hex.DecodeString("80")
338	if err != nil {
339		t.Errorf("hex.DecodeString() err = %v, want nil", err)
340	}
341	for _, test := range []struct {
342		name      string
343		keyFormat []byte
344	}{
345		{
346			name:      "nil",
347			keyFormat: nil,
348		},
349		{
350			name:      "empty",
351			keyFormat: []byte{},
352		},
353		{
354			name:      "invalid serialization",
355			keyFormat: invalidSerialization,
356		},
357	} {
358		t.Run(test.name, func(t *testing.T) {
359			buf := bytes.NewBuffer(random.GetRandomBytes(16))
360			if _, err := keyManager.DeriveKey(test.keyFormat, buf); err == nil {
361				t.Errorf("keyManager.DeriveKey() err = nil, want non-nil")
362			}
363		})
364	}
365}
366
367func TestHMACDeriveKeyFailsWithInsufficientRandomness(t *testing.T) {
368	km, err := registry.GetKeyManager(testutil.HMACPRFTypeURL)
369	if err != nil {
370		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.HMACPRFTypeURL, err)
371	}
372	keyManager, ok := km.(internalregistry.DerivableKeyManager)
373	if !ok {
374		t.Fatalf("key manager is not DerivableKeyManager")
375	}
376	keyFormat, err := proto.Marshal(&hmacpb.HmacPrfKeyFormat{
377		Version: testutil.HMACPRFKeyVersion,
378		KeySize: 16,
379		Params:  &hmacpb.HmacPrfParams{Hash: commonpb.HashType_SHA256},
380	})
381	if err != nil {
382		t.Fatalf("proto.Marshal(%v) err = %v, want nil", keyFormat, err)
383	}
384	{
385		buf := bytes.NewBuffer(random.GetRandomBytes(16))
386		if _, err := keyManager.DeriveKey(keyFormat, buf); err != nil {
387			t.Errorf("keyManager.DeriveKey() err = %v, want nil", err)
388		}
389	}
390	{
391		insufficientBuf := bytes.NewBuffer(random.GetRandomBytes(15))
392		if _, err := keyManager.DeriveKey(keyFormat, insufficientBuf); err == nil {
393			t.Errorf("keyManager.DeriveKey() err = nil, want non-nil")
394		}
395	}
396}
397
398func genInvalidHMACPRFKeys() []proto.Message {
399	badVersionKey := testutil.NewHMACPRFKey(commonpb.HashType_SHA256)
400	badVersionKey.Version++
401	shortKey := testutil.NewHMACPRFKey(commonpb.HashType_SHA256)
402	shortKey.KeyValue = []byte{1, 1}
403	return []proto.Message{
404		// not a HMACPRFKey
405		testutil.NewHMACParams(commonpb.HashType_SHA256, 32),
406		// bad version
407		badVersionKey,
408		// key too short
409		shortKey,
410		// unknown hash type
411		testutil.NewHMACPRFKey(commonpb.HashType_UNKNOWN_HASH),
412	}
413}
414
415func genInvalidHMACPRFKeyFormats() []proto.Message {
416	shortKeyFormat := testutil.NewHMACPRFKeyFormat(commonpb.HashType_SHA256)
417	shortKeyFormat.KeySize = 1
418	return []proto.Message{
419		// not a HMACPRFKeyFormat
420		testutil.NewHMACParams(commonpb.HashType_SHA256, 32),
421		// key too short
422		shortKeyFormat,
423		// unknown hash type
424		testutil.NewHMACPRFKeyFormat(commonpb.HashType_UNKNOWN_HASH),
425	}
426}
427
428func genValidHMACPRFKeyFormats() []*hmacpb.HmacPrfKeyFormat {
429	return []*hmacpb.HmacPrfKeyFormat{
430		testutil.NewHMACPRFKeyFormat(commonpb.HashType_SHA1),
431		testutil.NewHMACPRFKeyFormat(commonpb.HashType_SHA256),
432		testutil.NewHMACPRFKeyFormat(commonpb.HashType_SHA512),
433	}
434}
435
436func genValidHMACPRFKeys() []*hmacpb.HmacPrfKey {
437	return []*hmacpb.HmacPrfKey{
438		testutil.NewHMACPRFKey(commonpb.HashType_SHA1),
439		testutil.NewHMACPRFKey(commonpb.HashType_SHA256),
440		testutil.NewHMACPRFKey(commonpb.HashType_SHA512),
441	}
442}
443
444// Checks whether the given HMACPRFKey matches the given key HMACPRFKeyFormat
445func validateHMACPRFKey(format *hmacpb.HmacPrfKeyFormat, key *hmacpb.HmacPrfKey) error {
446	if format.KeySize != uint32(len(key.KeyValue)) ||
447		key.Params.Hash != format.Params.Hash {
448		return fmt.Errorf("key format and generated key do not match")
449	}
450	p, err := subtle.NewHMACPRF(commonpb.HashType_name[int32(key.Params.Hash)], key.KeyValue)
451	if err != nil {
452		return fmt.Errorf("cannot create primitive from key: %s", err)
453	}
454	return validateHMACPRFPrimitive(p, key)
455}
456
457// validateHMACPRFPrimitive checks whether the given primitive can compute a PRF of length 16
458func validateHMACPRFPrimitive(p interface{}, key *hmacpb.HmacPrfKey) error {
459	hmac := p.(prf.PRF)
460	prfPrimitive, err := subtle.NewHMACPRF(commonpb.HashType_name[int32(key.Params.Hash)], key.KeyValue)
461	if err != nil {
462		return fmt.Errorf("Could not create HMAC PRF with key material %q: %s", hex.EncodeToString(key.KeyValue), err)
463	}
464	data := random.GetRandomBytes(20)
465	res, err := hmac.ComputePRF(data, 16)
466	if err != nil {
467		return fmt.Errorf("prf computation failed: %s", err)
468	}
469	if len(res) != 16 {
470		return fmt.Errorf("prf computation did not produce 16 byte output")
471	}
472	res2, err := prfPrimitive.ComputePRF(data, 16)
473	if err != nil {
474		return fmt.Errorf("prf computation failed: %s", err)
475	}
476	if len(res2) != 16 {
477		return fmt.Errorf("prf computation did not produce 16 byte output")
478	}
479	if hex.EncodeToString(res) != hex.EncodeToString(res2) {
480		return fmt.Errorf("prf computation did not produce the same output for the same key and input")
481	}
482	return nil
483}
484