xref: /aosp_15_r20/external/tink/go/jwt/jwt_hmac_key_manager_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package jwt
18
19import (
20	"encoding/base64"
21	"testing"
22	"time"
23
24	"github.com/google/go-cmp/cmp"
25	"google.golang.org/protobuf/proto"
26	"github.com/google/tink/go/core/registry"
27	"github.com/google/tink/go/subtle/random"
28	jwtmacpb "github.com/google/tink/go/proto/jwt_hmac_go_proto"
29	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
30)
31
32type jwtKeyManagerTestCase struct {
33	tag       string
34	keyFormat *jwtmacpb.JwtHmacKeyFormat
35	key       *jwtmacpb.JwtHmacKey
36}
37
// typeURL is the type URL the tests use to look up the JWT HMAC key manager
// in the registry; it is also the type URL expected on generated KeyData.
const typeURL = "type.googleapis.com/google.crypto.tink.JwtHmacKey"
41
42func generateKeyFormat(keySize uint32, algorithm jwtmacpb.JwtHmacAlgorithm) *jwtmacpb.JwtHmacKeyFormat {
43	return &jwtmacpb.JwtHmacKeyFormat{
44		KeySize:   keySize,
45		Algorithm: algorithm,
46	}
47}
48
49func TestDoesSupport(t *testing.T) {
50	km, err := registry.GetKeyManager(typeURL)
51	if err != nil {
52		t.Errorf("registry.GetKeyManager(%q) error = %v, want nil", typeURL, err)
53	}
54	if !km.DoesSupport(typeURL) {
55		t.Errorf("km.DoesSupport(%q) = false, want true", typeURL)
56	}
57}
58
59func TestTypeURL(t *testing.T) {
60	km, err := registry.GetKeyManager(typeURL)
61	if err != nil {
62		t.Errorf("registry.GetKeyManager(%q) error = %v, want nil", typeURL, err)
63	}
64	if km.TypeURL() != typeURL {
65		t.Errorf("km.TypeURL() = %q, want %q", km.TypeURL(), typeURL)
66	}
67}
68
69var invalidKeyFormatTestCases = []jwtKeyManagerTestCase{
70	{
71		tag:       "invalid hash algorithm",
72		keyFormat: generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS_UNKNOWN),
73	},
74	{
75		tag:       "invalid HS256 key size",
76		keyFormat: generateKeyFormat(31, jwtmacpb.JwtHmacAlgorithm_HS256),
77	},
78	{
79		tag:       "invalid HS384 key size",
80		keyFormat: generateKeyFormat(47, jwtmacpb.JwtHmacAlgorithm_HS384),
81	},
82	{
83		tag:       "invalid HS512 key size",
84		keyFormat: generateKeyFormat(63, jwtmacpb.JwtHmacAlgorithm_HS512),
85	},
86	{
87		tag:       "empty key format",
88		keyFormat: &jwtmacpb.JwtHmacKeyFormat{},
89	},
90	{
91		tag:       "nil key format",
92		keyFormat: nil,
93	},
94}
95
96func TestNewKeyInvalidFormatFails(t *testing.T) {
97	for _, tc := range invalidKeyFormatTestCases {
98		t.Run(tc.tag, func(t *testing.T) {
99			km, err := registry.GetKeyManager(typeURL)
100			if err != nil {
101				t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
102			}
103			serializedKeyFormat, err := proto.Marshal(tc.keyFormat)
104			if err != nil {
105				t.Errorf("serializing key format: %v", err)
106			}
107			if _, err := km.NewKey(serializedKeyFormat); err == nil {
108				t.Errorf("km.NewKey() err = nil, want error")
109			}
110		})
111	}
112}
113
114func TestNewDataInvalidFormatFails(t *testing.T) {
115	for _, tc := range invalidKeyFormatTestCases {
116		t.Run(tc.tag, func(t *testing.T) {
117			km, err := registry.GetKeyManager(typeURL)
118			if err != nil {
119				t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
120			}
121			serializedKeyFormat, err := proto.Marshal(tc.keyFormat)
122			if err != nil {
123				t.Errorf("serializing key format: %v", err)
124			}
125			if _, err := km.NewKeyData(serializedKeyFormat); err == nil {
126				t.Errorf("km.NewKey() err = nil, want error")
127			}
128		})
129	}
130}
131
132var validKeyFormatTestCases = []jwtKeyManagerTestCase{
133	{
134		tag:       "SHA256 hash algorithm",
135		keyFormat: generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS256),
136	},
137	{
138		tag:       "SHA384 hash algorithm",
139		keyFormat: generateKeyFormat(48, jwtmacpb.JwtHmacAlgorithm_HS384),
140	},
141	{
142		tag:       "SHA512 hash algorithm",
143		keyFormat: generateKeyFormat(64, jwtmacpb.JwtHmacAlgorithm_HS512),
144	},
145}
146
147func TestNewKey(t *testing.T) {
148	for _, tc := range validKeyFormatTestCases {
149		t.Run(tc.tag, func(t *testing.T) {
150			km, err := registry.GetKeyManager(typeURL)
151			if err != nil {
152				t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
153			}
154			serializedKeyFormat, err := proto.Marshal(tc.keyFormat)
155			if err != nil {
156				t.Errorf("serializing key format: %v", err)
157			}
158			k, err := km.NewKey(serializedKeyFormat)
159			if err != nil {
160				t.Errorf("km.NewKey() err = %v, want nil", err)
161			}
162			key, ok := k.(*jwtmacpb.JwtHmacKey)
163			if !ok {
164				t.Errorf("key isn't of type JwtHmacKey")
165			}
166			if key.Algorithm != tc.keyFormat.Algorithm {
167				t.Errorf("k.Algorithm = %v, want %v", key.Algorithm, tc.keyFormat.Algorithm)
168			}
169			if len(key.KeyValue) != int(tc.keyFormat.KeySize) {
170				t.Errorf("len(key.KeyValue) = %d, want %d", len(key.KeyValue), tc.keyFormat.KeySize)
171			}
172		})
173	}
174}
175
176func TestNewKeyData(t *testing.T) {
177	for _, tc := range validKeyFormatTestCases {
178		t.Run(tc.tag, func(t *testing.T) {
179			km, err := registry.GetKeyManager(typeURL)
180			if err != nil {
181				t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
182			}
183			serializedKeyFormat, err := proto.Marshal(tc.keyFormat)
184			if err != nil {
185				t.Errorf("serializing key format: %v", err)
186			}
187			k, err := km.NewKeyData(serializedKeyFormat)
188			if err != nil {
189				t.Errorf("km.NewKeyData() err = %v, want nil", err)
190			}
191			if k.GetTypeUrl() != typeURL {
192				t.Errorf("k.GetTypeUrl() = %q, want %q", k.GetTypeUrl(), typeURL)
193			}
194			if k.GetKeyMaterialType() != tinkpb.KeyData_SYMMETRIC {
195				t.Errorf("k.GetKeyMaterialType() = %q, want %q", k.GetKeyMaterialType(), tinkpb.KeyData_SYMMETRIC)
196			}
197		})
198	}
199}
200
201func generateKey(keySize, version uint32, algorithm jwtmacpb.JwtHmacAlgorithm, kid *jwtmacpb.JwtHmacKey_CustomKid) *jwtmacpb.JwtHmacKey {
202	return &jwtmacpb.JwtHmacKey{
203		KeyValue:  random.GetRandomBytes(keySize),
204		Algorithm: algorithm,
205		CustomKid: kid,
206		Version:   version,
207	}
208}
209
210func TestGetPrimitiveWithValidKeys(t *testing.T) {
211	rawJWT, err := NewRawJWT(&RawJWTOptions{WithoutExpiration: true, Audiences: []string{"tink-aud"}})
212	if err != nil {
213		t.Fatalf("NewRawJWT() err = %v, want nil", err)
214	}
215	validator, err := NewValidator(&ValidatorOpts{AllowMissingExpiration: true, ExpectedAudience: refString("tink-aud")})
216	if err != nil {
217		t.Fatalf("NewValidator() err = %v, want nil", err)
218	}
219	for _, tc := range []jwtKeyManagerTestCase{
220		{
221			tag: "SHA256 hash algorithm",
222			key: generateKey(32, 0, jwtmacpb.JwtHmacAlgorithm_HS256, nil),
223		},
224		{
225			tag: "SHA384 hash algorithm",
226			key: generateKey(48, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil),
227		},
228		{
229			tag: "SHA512 hash algorithm",
230			key: generateKey(64, 0, jwtmacpb.JwtHmacAlgorithm_HS512, nil),
231		},
232		{
233			tag: "with custom kid",
234			key: generateKey(64, 0, jwtmacpb.JwtHmacAlgorithm_HS512, &jwtmacpb.JwtHmacKey_CustomKid{Value: "1235"}),
235		},
236	} {
237		t.Run(tc.tag, func(t *testing.T) {
238			km, err := registry.GetKeyManager(typeURL)
239			if err != nil {
240				t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
241			}
242			serializedKey, err := proto.Marshal(tc.key)
243			if err != nil {
244				t.Errorf("serializing key format: %v", err)
245			}
246			p, err := km.Primitive(serializedKey)
247			if err != nil {
248				t.Errorf("km.Primitive() err = %v, want nil", err)
249			}
250			primitive, ok := p.(*macWithKID)
251			if !ok {
252				t.Errorf("primitive isn't of type: macWithKID")
253			}
254			compact, err := primitive.ComputeMACAndEncodeWithKID(rawJWT, nil)
255			if err != nil {
256				t.Errorf("ComputeMACAndEncodeWithKID() err = %v, want nil", err)
257			}
258			verifiedJWT, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, nil)
259			if err != nil {
260				t.Errorf("VerifyMACAndDecodeWithKID() err = %v, want nil", err)
261			}
262			audiences, err := verifiedJWT.Audiences()
263			if err != nil {
264				t.Errorf("verifiedJWT.Audiences() err = %v, want nil", err)
265			}
266			if !cmp.Equal(audiences, []string{"tink-aud"}) {
267				t.Errorf("verifiedJWT.Audiences() = %q, want ['tink-aud']", audiences)
268			}
269
270		})
271	}
272}
273
274func TestGetPrimitiveWithInvalidKeys(t *testing.T) {
275	for _, tc := range []jwtKeyManagerTestCase{
276		{
277			tag: "HS256",
278			key: generateKey(31, 0, jwtmacpb.JwtHmacAlgorithm_HS256, nil),
279		},
280		{
281			tag: "HS384",
282			key: generateKey(47, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil),
283		},
284		{
285			tag: "HS512",
286			key: generateKey(63, 0, jwtmacpb.JwtHmacAlgorithm_HS512, nil),
287		},
288	} {
289		t.Run(tc.tag, func(t *testing.T) {
290			km, err := registry.GetKeyManager(typeURL)
291			if err != nil {
292				t.Fatalf("registry.GetKeyManager(%q) err=%q, want nil", typeURL, err)
293			}
294			serializedKey, err := proto.Marshal(tc.key)
295			if err != nil {
296				t.Fatalf("proto.Marshal(tc.key) err =%q, want nil", err)
297			}
298			_, err = km.Primitive(serializedKey)
299			if err == nil {
300				t.Error("km.Primitive(serializedKey) err = nil, want error")
301			}
302		})
303	}
304}
305
306func TestSpecyfingCustomKIDAndTINKKIDFails(t *testing.T) {
307	// key and compact are examples from: https://datatracker.ietf.org/doc/html/rfc7515#appendix-A.1.1
308	compact := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
309	rawKey, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString("AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow")
310	if err != nil {
311		t.Fatalf("failed decoding test key: %v", err)
312	}
313	key := &jwtmacpb.JwtHmacKey{
314		KeyValue:  rawKey,
315		Algorithm: jwtmacpb.JwtHmacAlgorithm_HS256,
316		CustomKid: &jwtmacpb.JwtHmacKey_CustomKid{Value: "1235"},
317		Version:   0,
318	}
319	km, err := registry.GetKeyManager(typeURL)
320	if err != nil {
321		t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
322	}
323	serializedKey, err := proto.Marshal(key)
324	if err != nil {
325		t.Errorf("serializing key format: %v", err)
326	}
327	p, err := km.Primitive(serializedKey)
328	if err != nil {
329		t.Errorf("km.Primitive() err = %v, want nil", err)
330	}
331	primitive, ok := p.(*macWithKID)
332	if !ok {
333		t.Errorf("primitive isn't of type: macWithKID")
334	}
335
336	rawJWT, err := NewRawJWT(&RawJWTOptions{WithoutExpiration: true})
337	if err != nil {
338		t.Errorf("creating new RawJWT: %v", err)
339	}
340	opts := &ValidatorOpts{
341		ExpectedTypeHeader: refString("JWT"),
342		ExpectedIssuer:     refString("joe"),
343		FixedNow:           time.Unix(12345, 0),
344	}
345	validator, err := NewValidator(opts)
346	if err != nil {
347		t.Errorf("creating new JWTValidator: %v", err)
348	}
349	if _, err := primitive.ComputeMACAndEncodeWithKID(rawJWT, refString("4566")); err == nil {
350		t.Errorf("primitive.ComputeMACAndEncodeWithKID() err = nil, want error")
351	}
352	if _, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, refString("4566")); err == nil {
353		t.Errorf("primitive.VerifyMACAndDecodeWithKID(kid = 4566) err = nil, want error")
354	}
355	// Verify success without KID
356	if _, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, nil); err != nil {
357		t.Errorf("primitive.VerifyMACAndDecodeWithKID(kid = nil) err = %v, want nil", err)
358	}
359}
360
361func TestGetPrimitiveWithInvalidKeyFails(t *testing.T) {
362	for _, tc := range []jwtKeyManagerTestCase{
363		{
364			tag: "empty key",
365			key: &jwtmacpb.JwtHmacKey{},
366		},
367		{
368			tag: "nil key",
369			key: nil,
370		},
371		{
372			tag: "unsupported hash algorithm",
373			key: generateKey(32, 0, jwtmacpb.JwtHmacAlgorithm_HS_UNKNOWN, nil),
374		},
375		{
376			tag: "short key length",
377			key: generateKey(20, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil),
378		},
379		{
380			tag: "unsupported version",
381			key: generateKey(48, 1, jwtmacpb.JwtHmacAlgorithm_HS384, nil),
382		},
383	} {
384		t.Run(tc.tag, func(t *testing.T) {
385			km, err := registry.GetKeyManager(typeURL)
386			if err != nil {
387				t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
388			}
389			serializedKey, err := proto.Marshal(tc.key)
390			if err != nil {
391				t.Errorf("serializing key format: %v", err)
392			}
393			if _, err := km.Primitive(serializedKey); err == nil {
394				t.Errorf("km.Primitive() err = nil, want error")
395			}
396		})
397	}
398}
399
400func TestGeneratesDifferentKeys(t *testing.T) {
401	km, err := registry.GetKeyManager(typeURL)
402	if err != nil {
403		t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err)
404	}
405	serializedKeyFormat, err := proto.Marshal(generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS256))
406	if err != nil {
407		t.Errorf("serializing key format: %v", err)
408	}
409	k1, err := km.NewKey(serializedKeyFormat)
410	if err != nil {
411		t.Errorf("km.NewKey() err = %v, want nil", err)
412	}
413	k2, err := km.NewKey(serializedKeyFormat)
414	if err != nil {
415		t.Errorf("km.NewKey() err = %v, want nil", err)
416	}
417	key1, ok := k1.(*jwtmacpb.JwtHmacKey)
418	if !ok {
419		t.Errorf("k1 isn't of type JwtHmacKey")
420	}
421	key2, ok := k2.(*jwtmacpb.JwtHmacKey)
422	if !ok {
423		t.Errorf("k2 isn't of type JwtHmacKey")
424	}
425	if cmp.Equal(key1.GetKeyValue(), key2.GetKeyValue()) {
426		t.Errorf("key material should differ")
427	}
428}
429