xref: /aosp_15_r20/external/tink/go/keyderivation/prf_based_deriver_key_manager_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package keyderivation_test
18
19import (
20	"fmt"
21	"testing"
22
23	"github.com/google/go-cmp/cmp"
24	"google.golang.org/protobuf/proto"
25	"google.golang.org/protobuf/testing/protocmp"
26	"github.com/google/tink/go/aead"
27	"github.com/google/tink/go/core/registry"
28	"github.com/google/tink/go/keyderivation"
29	"github.com/google/tink/go/prf"
30	"github.com/google/tink/go/subtle/random"
31	aesgcmpb "github.com/google/tink/go/proto/aes_gcm_go_proto"
32	prfderpb "github.com/google/tink/go/proto/prf_based_deriver_go_proto"
33	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
34)
35
const (
	// prfBasedDeriverTypeURL is the type URL used by these tests to look up
	// the PRF-based deriver key manager in the registry.
	prfBasedDeriverTypeURL = "type.googleapis.com/google.crypto.tink.PrfBasedDeriverKey"

	// prfBasedDeriverKeyVersion is the key version the tests expect on keys
	// produced by the key manager. It is left untyped so that it can be
	// compared directly against the proto's version field.
	prfBasedDeriverKeyVersion = 0
)
40
41func TestPRFBasedDeriverKeyManagerPrimitive(t *testing.T) {
42	km, err := registry.GetKeyManager(prfBasedDeriverTypeURL)
43	if err != nil {
44		t.Fatalf("GetKeyManager(%q) err = %v, want nil", prfBasedDeriverTypeURL, err)
45	}
46	prfs := []struct {
47		name     string
48		template *tinkpb.KeyTemplate
49	}{
50		{
51			name:     "HKDF-SHA256",
52			template: prf.HKDFSHA256PRFKeyTemplate(),
53		},
54	}
55	derivations := []struct {
56		name     string
57		template *tinkpb.KeyTemplate
58	}{
59		{
60			name:     "AES128GCM",
61			template: aead.AES128GCMKeyTemplate(),
62		},
63		{
64			name:     "AES256GCM",
65			template: aead.AES256GCMKeyTemplate(),
66		},
67		{
68			name:     "AES256GCMNoPrefix",
69			template: aead.AES256GCMNoPrefixKeyTemplate(),
70		},
71	}
72	for _, prf := range prfs {
73		for _, der := range derivations {
74			for _, salt := range [][]byte{nil, []byte("salt")} {
75				name := fmt.Sprintf("%s_%s", prf.name, der.name)
76				if salt != nil {
77					name += "_with_salt"
78				}
79				t.Run(name, func(t *testing.T) {
80					prfKey, err := registry.NewKeyData(prf.template)
81					if err != nil {
82						t.Fatalf("registry.NewKeyData() err = %v, want nil", err)
83					}
84					key := &prfderpb.PrfBasedDeriverKey{
85						Version: 0,
86						PrfKey:  prfKey,
87						Params: &prfderpb.PrfBasedDeriverParams{
88							DerivedKeyTemplate: der.template,
89						},
90					}
91					serializedKey, err := proto.Marshal(key)
92					if err != nil {
93						t.Fatalf("proto.Marshal(%v) err = %v, want nil", key, err)
94					}
95					p, err := km.Primitive(serializedKey)
96					if err != nil {
97						t.Fatalf("Primitive() err = %v, want nil", err)
98					}
99					d, ok := p.(keyderivation.KeysetDeriver)
100					if !ok {
101						t.Fatal("primitive is not KeysetDeriver")
102					}
103					if _, err := d.DeriveKeyset(salt); err != nil {
104						t.Fatalf("DeriveKeyset() err = %v, want nil", err)
105					}
106					// We cannot test the derived keyset handle because, at this point, it
107					// is filled with placeholder values for the key ID, status, and
108					// output prefix type fields.
109				})
110			}
111		}
112	}
113}
114
115func TestPRFBasedDeriverKeyManagerPrimitiveRejectsIncorrectKeys(t *testing.T) {
116	km, err := registry.GetKeyManager(prfBasedDeriverTypeURL)
117	if err != nil {
118		t.Fatalf("GetKeyManager(%q) err = %v, want nil", prfBasedDeriverTypeURL, err)
119	}
120	prfKey, err := registry.NewKeyData(prf.HKDFSHA256PRFKeyTemplate())
121	if err != nil {
122		t.Fatalf("registry.NewKeyData() err = %v, want nil", err)
123	}
124	missingParamsKey := &prfderpb.PrfBasedDeriverKey{
125		Version: prfBasedDeriverKeyVersion,
126		PrfKey:  prfKey,
127	}
128	serializedMissingParamsKey, err := proto.Marshal(missingParamsKey)
129	if err != nil {
130		t.Fatalf("proto.Marshal(%v) err = %v, want nil", serializedMissingParamsKey, err)
131	}
132	aesGCMKey := &aesgcmpb.AesGcmKey{Version: 0, KeyValue: random.GetRandomBytes(32)}
133	serializedAESGCMKey, err := proto.Marshal(aesGCMKey)
134	if err != nil {
135		t.Fatalf("proto.Marshal(%v) err = %v, want nil", aesGCMKey, err)
136	}
137	for _, test := range []struct {
138		name          string
139		serializedKey []byte
140	}{
141		{
142			name: "nil key",
143		},
144		{
145			name:          "zero-length key",
146			serializedKey: []byte{},
147		},
148		{
149			name:          "missing params",
150			serializedKey: serializedMissingParamsKey,
151		},
152		{
153			name:          "wrong key type",
154			serializedKey: serializedAESGCMKey,
155		},
156	} {
157		t.Run(test.name, func(t *testing.T) {
158			if _, err := km.Primitive(test.serializedKey); err == nil {
159				t.Error("Primitive() err = nil, want non-nil")
160			}
161		})
162	}
163}
164
165func TestPRFBasedDeriverKeyManagerPrimitiveRejectsInvalidKeys(t *testing.T) {
166	km, err := registry.GetKeyManager(prfBasedDeriverTypeURL)
167	if err != nil {
168		t.Fatalf("GetKeyManager(%q) err = %v, want nil", prfBasedDeriverTypeURL, err)
169	}
170
171	validPRFKey, err := registry.NewKeyData(prf.HKDFSHA256PRFKeyTemplate())
172	if err != nil {
173		t.Fatalf("registry.NewKeyData() err = %v, want nil", err)
174	}
175	validKey := &prfderpb.PrfBasedDeriverKey{
176		Version: 0,
177		PrfKey:  validPRFKey,
178		Params: &prfderpb.PrfBasedDeriverParams{
179			DerivedKeyTemplate: aead.AES128GCMKeyTemplate(),
180		},
181	}
182	serializedValidKey, err := proto.Marshal(validKey)
183	if err != nil {
184		t.Fatalf("proto.Marshal(%v) err = %v, want nil", validKey, err)
185	}
186	if _, err := km.Primitive(serializedValidKey); err != nil {
187		t.Errorf("Primitive() err = %v, want nil", err)
188	}
189
190	invalidPRFKey, err := registry.NewKeyData(aead.AES128GCMKeyTemplate())
191	if err != nil {
192		t.Fatalf("registry.NewKeyData() err = %v, want nil", err)
193	}
194
195	for _, test := range []struct {
196		name           string
197		version        uint32
198		prfKey         *tinkpb.KeyData
199		derKeyTemplate *tinkpb.KeyTemplate
200	}{
201		{
202			name:           "invalid version",
203			version:        100,
204			prfKey:         validKey.GetPrfKey(),
205			derKeyTemplate: validKey.GetParams().GetDerivedKeyTemplate(),
206		},
207		{
208			name:           "invalid PRF key",
209			version:        validKey.GetVersion(),
210			prfKey:         invalidPRFKey,
211			derKeyTemplate: validKey.GetParams().GetDerivedKeyTemplate(),
212		},
213		{
214			name:           "invalid derived key template",
215			version:        validKey.GetVersion(),
216			prfKey:         validKey.GetPrfKey(),
217			derKeyTemplate: aead.AES128CTRHMACSHA256KeyTemplate(),
218		},
219	} {
220		t.Run(test.name, func(t *testing.T) {
221			key := &prfderpb.PrfBasedDeriverKey{
222				Version: test.version,
223				PrfKey:  test.prfKey,
224				Params: &prfderpb.PrfBasedDeriverParams{
225					DerivedKeyTemplate: test.derKeyTemplate,
226				},
227			}
228			serializedKey, err := proto.Marshal(key)
229			if err != nil {
230				t.Fatalf("proto.Marshal(%v) err = %v, want nil", key, err)
231			}
232			if _, err := km.Primitive(serializedKey); err == nil {
233				t.Error("Primitive() err = nil, want non-nil")
234			}
235		})
236	}
237}
238
239func TestPRFBasedDeriverKeyManagerNewKey(t *testing.T) {
240	km, err := registry.GetKeyManager(prfBasedDeriverTypeURL)
241	if err != nil {
242		t.Fatalf("GetKeyManager(%q) err = %v, want nil", prfBasedDeriverTypeURL, err)
243	}
244	prfs := []struct {
245		name     string
246		template *tinkpb.KeyTemplate
247	}{
248		{
249			name:     "HKDF-SHA256",
250			template: prf.HKDFSHA256PRFKeyTemplate(),
251		},
252	}
253	derivations := []struct {
254		name     string
255		template *tinkpb.KeyTemplate
256	}{
257		{
258			name:     "AES128GCM",
259			template: aead.AES128GCMKeyTemplate(),
260		},
261		{
262			name:     "AES256GCM",
263			template: aead.AES256GCMKeyTemplate(),
264		},
265		{
266			name:     "AES256GCMNoPrefix",
267			template: aead.AES256GCMNoPrefixKeyTemplate(),
268		},
269	}
270	for _, prf := range prfs {
271		for _, der := range derivations {
272			for _, salt := range [][]byte{nil, []byte("salt")} {
273				name := fmt.Sprintf("%s_%s", prf.name, der.name)
274				if salt != nil {
275					name += "_with_salt"
276				}
277				t.Run(name, func(t *testing.T) {
278					keyFormat := &prfderpb.PrfBasedDeriverKeyFormat{
279						PrfKeyTemplate: prf.template,
280						Params: &prfderpb.PrfBasedDeriverParams{
281							DerivedKeyTemplate: der.template,
282						},
283					}
284					serializedKeyFormat, err := proto.Marshal(keyFormat)
285					if err != nil {
286						t.Fatalf("proto.Marshal(%v) err = %v, want nil", keyFormat, err)
287					}
288					k, err := km.NewKey(serializedKeyFormat)
289					if err != nil {
290						t.Errorf("NewKey() err = %v, want nil", err)
291					}
292					key, ok := k.(*prfderpb.PrfBasedDeriverKey)
293					if !ok {
294						t.Fatal("key is not PrfBasedDeriverKey")
295					}
296					if key.GetVersion() != prfBasedDeriverKeyVersion {
297						t.Errorf("GetVersion() = %d, want 0", key.GetVersion())
298					}
299					prfKeyData := key.GetPrfKey()
300					if got, want := prfKeyData.GetTypeUrl(), prf.template.GetTypeUrl(); got != want {
301						t.Errorf("GetTypeUrl() = %q, want %q", got, want)
302					}
303					if got, want := prfKeyData.GetKeyMaterialType(), tinkpb.KeyData_SYMMETRIC; got != want {
304						t.Errorf("GetKeyMaterialType() = %s, want %s", got, want)
305					}
306					if diff := cmp.Diff(key.GetParams().GetDerivedKeyTemplate(), der.template, protocmp.Transform()); diff != "" {
307						t.Errorf("GetDerivedKeyTemplate() diff = %s", diff)
308					}
309				})
310			}
311		}
312	}
313}
314
315func TestPRFBasedDeriverKeyManagerNewKeyData(t *testing.T) {
316	km, err := registry.GetKeyManager(prfBasedDeriverTypeURL)
317	if err != nil {
318		t.Fatalf("GetKeyManager(%q) err = %v, want nil", prfBasedDeriverTypeURL, err)
319	}
320	prfs := []struct {
321		name     string
322		template *tinkpb.KeyTemplate
323	}{
324		{
325			name:     "HKDF-SHA256",
326			template: prf.HKDFSHA256PRFKeyTemplate(),
327		},
328	}
329	derivations := []struct {
330		name     string
331		template *tinkpb.KeyTemplate
332	}{
333		{
334			name:     "AES128GCM",
335			template: aead.AES128GCMKeyTemplate(),
336		},
337		{
338			name:     "AES256GCM",
339			template: aead.AES256GCMKeyTemplate(),
340		},
341		{
342			name:     "AES256GCMNoPrefix",
343			template: aead.AES256GCMNoPrefixKeyTemplate(),
344		},
345	}
346	for _, prf := range prfs {
347		for _, der := range derivations {
348			for _, salt := range [][]byte{nil, []byte("salt")} {
349				name := fmt.Sprintf("%s_%s", prf.name, der.name)
350				if salt != nil {
351					name += "_with_salt"
352				}
353				t.Run(name, func(t *testing.T) {
354					keyFormat := &prfderpb.PrfBasedDeriverKeyFormat{
355						PrfKeyTemplate: prf.template,
356						Params: &prfderpb.PrfBasedDeriverParams{
357							DerivedKeyTemplate: der.template,
358						},
359					}
360					serializedKeyFormat, err := proto.Marshal(keyFormat)
361					if err != nil {
362						t.Fatalf("proto.Marshal(%v) err = %v, want nil", keyFormat, err)
363					}
364					keyData, err := km.NewKeyData(serializedKeyFormat)
365					if err != nil {
366						t.Errorf("NewKeyData() err = %v, want nil", err)
367					}
368					if keyData.GetTypeUrl() != prfBasedDeriverTypeURL {
369						t.Errorf("GetTypeUrl() = %s, want %s", keyData.GetTypeUrl(), prfBasedDeriverTypeURL)
370					}
371					if keyData.GetKeyMaterialType() != tinkpb.KeyData_SYMMETRIC {
372						t.Errorf("GetKeyMaterialType() = %s, want %s", keyData.GetKeyMaterialType(), tinkpb.KeyData_SYMMETRIC)
373					}
374					key := &prfderpb.PrfBasedDeriverKey{}
375					if err := proto.Unmarshal(keyData.GetValue(), key); err != nil {
376						t.Fatalf("proto.Unmarshal() err = %v, want nil", err)
377					}
378					if key.GetVersion() != prfBasedDeriverKeyVersion {
379						t.Errorf("GetVersion() = %d, want %d", key.GetVersion(), prfBasedDeriverKeyVersion)
380					}
381					prfKeyData := key.GetPrfKey()
382					if got, want := prfKeyData.GetTypeUrl(), prf.template.GetTypeUrl(); got != want {
383						t.Errorf("GetTypeUrl() = %q, want %q", got, want)
384					}
385					if got, want := prfKeyData.GetKeyMaterialType(), tinkpb.KeyData_SYMMETRIC; got != want {
386						t.Errorf("GetKeyMaterialType() = %s, want %s", got, want)
387					}
388					if diff := cmp.Diff(key.GetParams().GetDerivedKeyTemplate(), der.template, protocmp.Transform()); diff != "" {
389						t.Errorf("GetDerivedKeyTemplate() diff = %s", diff)
390					}
391				})
392			}
393		}
394	}
395}
396
397func TestPRFBasedDeriverKeyManagerNewKeyAndNewKeyDataRejectsIncorrectKeyFormats(t *testing.T) {
398	km, err := registry.GetKeyManager(prfBasedDeriverTypeURL)
399	if err != nil {
400		t.Fatalf("GetKeyManager(%q) err = %v, want nil", prfBasedDeriverTypeURL, err)
401	}
402	missingParamsKeyFormat := &prfderpb.PrfBasedDeriverKeyFormat{
403		PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(),
404	}
405	serializedMissingParamsKeyFormat, err := proto.Marshal(missingParamsKeyFormat)
406	if err != nil {
407		t.Fatalf("proto.Marshal(%v) err = %v, want nil", missingParamsKeyFormat, err)
408	}
409	aesGCMKeyFormat := &aesgcmpb.AesGcmKeyFormat{KeySize: 32, Version: 0}
410	serializedAESGCMKeyFormat, err := proto.Marshal(aesGCMKeyFormat)
411	if err != nil {
412		t.Fatalf("proto.Marshal(%v) err = %v, want nil", aesGCMKeyFormat, err)
413	}
414	for _, test := range []struct {
415		name                string
416		serializedKeyFormat []byte
417	}{
418		{
419			name: "nil key",
420		},
421		{
422			name:                "zero-length key",
423			serializedKeyFormat: []byte{},
424		},
425		{
426			name:                "missing params",
427			serializedKeyFormat: serializedMissingParamsKeyFormat,
428		},
429		{
430			name:                "wrong key type",
431			serializedKeyFormat: serializedAESGCMKeyFormat,
432		},
433	} {
434		t.Run(test.name, func(t *testing.T) {
435			if _, err := km.NewKey(test.serializedKeyFormat); err == nil {
436				t.Error("NewKey() err = nil, want non-nil")
437			}
438			if _, err := km.NewKeyData(test.serializedKeyFormat); err == nil {
439				t.Error("NewKeyData() err = nil, want non-nil")
440			}
441		})
442	}
443}
444
445func TestPRFBasedDeriverKeyManagerNewKeyAndNewKeyDataRejectsInvalidKeyFormats(t *testing.T) {
446	km, err := registry.GetKeyManager(prfBasedDeriverTypeURL)
447	if err != nil {
448		t.Fatalf("GetKeyManager(%q) err = %v, want nil", prfBasedDeriverTypeURL, err)
449	}
450
451	validKeyFormat := &prfderpb.PrfBasedDeriverKeyFormat{
452		PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(),
453		Params: &prfderpb.PrfBasedDeriverParams{
454			DerivedKeyTemplate: aead.AES128GCMKeyTemplate(),
455		},
456	}
457	serializedValidKeyFormat, err := proto.Marshal(validKeyFormat)
458	if err != nil {
459		t.Fatalf("proto.Marshal(%v) err = %v, want nil", validKeyFormat, err)
460	}
461	if _, err := km.NewKey(serializedValidKeyFormat); err != nil {
462		t.Errorf("Primitive() err = %v, want nil", err)
463	}
464
465	for _, test := range []struct {
466		name           string
467		prfKeyTemplate *tinkpb.KeyTemplate
468		derKeyTemplate *tinkpb.KeyTemplate
469	}{
470		{
471			"invalid PRF key template",
472			aead.AES128GCMKeyTemplate(),
473			validKeyFormat.GetParams().GetDerivedKeyTemplate(),
474		},
475		{
476			"invalid derived key template",
477			validKeyFormat.GetPrfKeyTemplate(),
478			aead.AES128CTRHMACSHA256KeyTemplate(),
479		},
480	} {
481		t.Run(test.name, func(t *testing.T) {
482			keyFormat := &prfderpb.PrfBasedDeriverKeyFormat{
483				PrfKeyTemplate: test.prfKeyTemplate,
484				Params: &prfderpb.PrfBasedDeriverParams{
485					DerivedKeyTemplate: test.derKeyTemplate,
486				},
487			}
488			serializedKeyFormat, err := proto.Marshal(keyFormat)
489			if err != nil {
490				t.Fatalf("proto.Marshal(%v) err = %v, want nil", keyFormat, err)
491			}
492			if _, err := km.NewKey(serializedKeyFormat); err == nil {
493				t.Error("NewKey() err = nil, want non-nil")
494			}
495			if _, err := km.NewKeyData(serializedKeyFormat); err == nil {
496				t.Error("NewKeyData() err = nil, want non-nil")
497			}
498		})
499	}
500}
501
502func TestPRFBasedDeriverKeyManagerDoesSupport(t *testing.T) {
503	km, err := registry.GetKeyManager(prfBasedDeriverTypeURL)
504	if err != nil {
505		t.Fatalf("GetKeyManager(%q) err = %v, want nil", prfBasedDeriverTypeURL, err)
506	}
507	if !km.DoesSupport(prfBasedDeriverTypeURL) {
508		t.Errorf("DoesSupport(%q) = false, want true", prfBasedDeriverTypeURL)
509	}
510	if unsupported := "unsupported.key.type"; km.DoesSupport(unsupported) {
511		t.Errorf("DoesSupport(%q) = true, want false", unsupported)
512	}
513}
514
515func TestPRFBasedDeriverKeyManagerTypeURL(t *testing.T) {
516	km, err := registry.GetKeyManager(prfBasedDeriverTypeURL)
517	if err != nil {
518		t.Fatalf("GetKeyManager(%q) err = %v, want nil", prfBasedDeriverTypeURL, err)
519	}
520	if km.TypeURL() != prfBasedDeriverTypeURL {
521		t.Errorf("TypeURL() = %q, want %q", km.TypeURL(), prfBasedDeriverTypeURL)
522	}
523}
524