xref: /aosp_15_r20/external/tink/go/jwt/verified_jwt_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package jwt_test
18
19import (
20	"testing"
21	"time"
22
23	"github.com/google/go-cmp/cmp"
24	"github.com/google/go-cmp/cmp/cmpopts"
25	"github.com/google/tink/go/jwt"
26	"github.com/google/tink/go/keyset"
27)
28
29func createVerifiedJWT(rawJWT *jwt.RawJWT) (*jwt.VerifiedJWT, error) {
30	kh, err := keyset.NewHandle(jwt.HS256Template())
31	if err != nil {
32		return nil, err
33	}
34	m, err := jwt.NewMAC(kh)
35	if err != nil {
36		return nil, err
37	}
38	compact, err := m.ComputeMACAndEncode(rawJWT)
39	if err != nil {
40		return nil, err
41	}
42	// This validator is purposely instantiated to always pass.
43	// It isn't really validating much and probably shouldn't
44	// be used like this out side of these tests.
45	opts := &jwt.ValidatorOpts{
46		AllowMissingExpiration: true,
47		IgnoreTypeHeader:       true,
48		IgnoreAudiences:        true,
49		IgnoreIssuer:           true,
50	}
51	issuedAt, err := rawJWT.IssuedAt()
52	if err == nil {
53		opts.FixedNow = issuedAt
54	}
55
56	validator, err := jwt.NewValidator(opts)
57	if err != nil {
58		return nil, err
59	}
60	return m.VerifyMACAndDecode(compact, validator)
61}
62
// TestGetRegisteredStringClaims checks that the "typ" header and the
// string-valued registered claims (sub, iss, jti) set on a RawJWT are reported
// as present on the verified token, round-trip to the same values, and that no
// custom claims are reported.
func TestGetRegisteredStringClaims(t *testing.T) {
	opts := &jwt.RawJWTOptions{
		TypeHeader:        refString("typeHeader"),
		Subject:           refString("test-subject"),
		Issuer:            refString("test-issuer"),
		JWTID:             refString("1"),
		WithoutExpiration: true,
	}
	rawJWT, err := jwt.NewRawJWT(opts)
	if err != nil {
		t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err)
	}
	verifiedJWT, err := createVerifiedJWT(rawJWT)
	if err != nil {
		t.Fatalf("creating verifiedJWT: %v", err)
	}
	if !verifiedJWT.HasTypeHeader() {
		t.Errorf("verifiedJWT.HasTypeHeader() = false, want true")
	}
	if !verifiedJWT.HasSubject() {
		t.Errorf("verifiedJWT.HasSubject() = false, want true")
	}
	if !verifiedJWT.HasIssuer() {
		t.Errorf("verifiedJWT.HasIssuer() = false, want true")
	}
	if !verifiedJWT.HasJWTID() {
		t.Errorf("verifiedJWT.HasJWTID() = false, want true")
	}
	typeHeader, err := verifiedJWT.TypeHeader()
	if err != nil {
		t.Errorf("verifiedJWT.TypeHeader() err = %v, want nil", err)
	}
	if !cmp.Equal(typeHeader, *opts.TypeHeader) {
		t.Errorf("verifiedJWT.TypeHeader() = %q, want %q", typeHeader, *opts.TypeHeader)
	}
	subject, err := verifiedJWT.Subject()
	if err != nil {
		t.Errorf("verifiedJWT.Subject() err = %v, want nil", err)
	}
	if !cmp.Equal(subject, *opts.Subject) {
		t.Errorf("verifiedJWT.Subject() = %q, want %q", subject, *opts.Subject)
	}
	issuer, err := verifiedJWT.Issuer()
	if err != nil {
		t.Errorf("verifiedJWT.Issuer() err = %v, want nil", err)
	}
	if !cmp.Equal(issuer, *opts.Issuer) {
		t.Errorf("verifiedJWT.Issuer() = %q, want %q", issuer, *opts.Issuer)
	}
	jwtID, err := verifiedJWT.JWTID()
	if err != nil {
		t.Errorf("verifiedJWT.JWTID() err = %v, want nil", err)
	}
	if !cmp.Equal(jwtID, *opts.JWTID) {
		t.Errorf("verifiedJWT.JWTID() = %q, want %q", jwtID, *opts.JWTID)
	}
	// cmp.Equal treats a nil slice and an empty non-nil slice as different, so
	// comparing against []string{} would make this test depend on which of the
	// two the API happens to return. "No custom claims" is a length question.
	if names := verifiedJWT.CustomClaimNames(); len(names) != 0 {
		t.Errorf("verifiedJWT.CustomClaimNames() = %q, want empty", names)
	}
}
123
// TestGetRegisteredTimestampClaims checks that the exp, iat and nbf claims set
// on a RawJWT are reported as present on the verified token and decode back to
// the same instants (whole seconds, since they are built from Unix times).
func TestGetRegisteredTimestampClaims(t *testing.T) {
	now := time.Now()
	opts := &jwt.RawJWTOptions{
		ExpiresAt: refTime(now.Add(24 * time.Hour).Unix()),
		IssuedAt:  refTime(now.Unix()),
		NotBefore: refTime(now.Add(-2 * time.Hour).Unix()),
	}
	rawJWT, err := jwt.NewRawJWT(opts)
	if err != nil {
		t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err)
	}
	verifiedJWT, err := createVerifiedJWT(rawJWT)
	if err != nil {
		t.Fatalf("creating verifiedJWT: %v", err)
	}

	presenceChecks := []struct {
		present bool
		msg     string
	}{
		{verifiedJWT.HasExpiration(), "verifiedJWT.HasExpiration() = false, want true"},
		{verifiedJWT.HasIssuedAt(), "verifiedJWT.HasIssuedAt() = false, want true"},
		{verifiedJWT.HasNotBefore(), "verifiedJWT.HasNotBefore() = false, want true"},
	}
	for _, pc := range presenceChecks {
		if !pc.present {
			t.Error(pc.msg)
		}
	}

	// checkTimestamp reads one timestamp claim via get, reports a lookup error
	// with errFormat, and reports a value mismatch against want with diffFormat.
	checkTimestamp := func(errFormat, diffFormat string, get func() (time.Time, error), want time.Time) {
		got, getErr := get()
		if getErr != nil {
			t.Errorf(errFormat, getErr)
		}
		if !cmp.Equal(got, want) {
			t.Errorf(diffFormat, got, want)
		}
	}
	checkTimestamp("verifiedJWT.ExpiresAt() err = %v, want nil", "verifiedJWT.ExpiresAt() = %q, want %q",
		verifiedJWT.ExpiresAt, *opts.ExpiresAt)
	checkTimestamp("verifiedJWT.IssuedAt() err = %v, want nil", "verifiedJWT.IssuedAt() = %q, want %q",
		verifiedJWT.IssuedAt, *opts.IssuedAt)
	checkTimestamp("verifiedJWT.NotBefore() err = %v, want nil", "verifiedJWT.NotBefore() = %q, want %q",
		verifiedJWT.NotBefore, *opts.NotBefore)
}
170
// TestGetAudiencesClaim checks that a multi-valued "aud" claim is reported as
// present on the verified token and round-trips in its original order.
func TestGetAudiencesClaim(t *testing.T) {
	wantAudiences := []string{"foo", "bar"}
	opts := &jwt.RawJWTOptions{
		WithoutExpiration: true,
		Audiences:         wantAudiences,
	}
	rawJWT, err := jwt.NewRawJWT(opts)
	if err != nil {
		t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err)
	}
	verifiedJWT, err := createVerifiedJWT(rawJWT)
	if err != nil {
		t.Fatalf("creating verifiedJWT: %v", err)
	}
	if present := verifiedJWT.HasAudiences(); !present {
		t.Errorf("verifiedJWT.HasAudiences() = false, want true")
	}
	gotAudiences, err := verifiedJWT.Audiences()
	if err != nil {
		t.Errorf("verifiedJWT.Audiences() err = %v, want nil", err)
	}
	if !cmp.Equal(gotAudiences, wantAudiences) {
		t.Errorf("verifiedJWT.Audiences() = %q, want %q", gotAudiences, wantAudiences)
	}
}
195
// TestGetCustomClaims checks that custom claims of every JSON kind (null,
// number, boolean, string, array, object) are listed by CustomClaimNames, are
// reported by the matching Has*Claim method, and round-trip through the
// matching *Claim getter.
func TestGetCustomClaims(t *testing.T) {
	opts := &jwt.RawJWTOptions{
		WithoutExpiration: true,
		CustomClaims: map[string]interface{}{
			"cc-null":   nil,
			"cc-num":    1.67,
			"cc-bool":   true,
			"cc-string": "goo",
			"cc-array":  []interface{}{"1", "2", "3"},
			"cc-object": map[string]interface{}{"cc-nested-num": 5.99},
		},
	}
	rawJWT, err := jwt.NewRawJWT(opts)
	if err != nil {
		t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err)
	}
	verifiedJWT, err := createVerifiedJWT(rawJWT)
	if err != nil {
		t.Fatalf("creating verifiedJWT: %v", err)
	}
	// Claim names come back in no guaranteed order, so compare as sets.
	wantCustomClaims := []string{"cc-num", "cc-bool", "cc-null", "cc-string", "cc-array", "cc-object"}
	if !cmp.Equal(verifiedJWT.CustomClaimNames(), wantCustomClaims, cmpopts.SortSlices(func(a, b string) bool { return a < b })) {
		t.Errorf("verifiedJWT.CustomClaimNames() = %q, want %q", verifiedJWT.CustomClaimNames(), wantCustomClaims)
	}
	if !verifiedJWT.HasNullClaim("cc-null") {
		t.Errorf("verifiedJWT.HasNullClaim('cc-null') = false, want true")
	}
	if !verifiedJWT.HasNumberClaim("cc-num") {
		t.Errorf("verifiedJWT.HasNumberClaim('cc-num') = false, want true")
	}
	if !verifiedJWT.HasBooleanClaim("cc-bool") {
		t.Errorf("verifiedJWT.HasBooleanClaim('cc-bool') = false, want true")
	}
	if !verifiedJWT.HasStringClaim("cc-string") {
		t.Errorf("verifiedJWT.HasStringClaim('cc-string') = false, want true")
	}
	if !verifiedJWT.HasArrayClaim("cc-array") {
		t.Errorf("verifiedJWT.HasArrayClaim('cc-array') = false, want true")
	}
	if !verifiedJWT.HasObjectClaim("cc-object") {
		t.Errorf("verifiedJWT.HasObjectClaim('cc-object') = false, want true")
	}
	number, err := verifiedJWT.NumberClaim("cc-num")
	if err != nil {
		t.Errorf("verifiedJWT.NumberClaim('cc-num') err = %v, want nil", err)
	}
	if !cmp.Equal(number, opts.CustomClaims["cc-num"]) {
		t.Errorf("verifiedJWT.NumberClaim('cc-num') = %f, want %f", number, opts.CustomClaims["cc-num"])
	}
	boolean, err := verifiedJWT.BooleanClaim("cc-bool")
	if err != nil {
		t.Errorf("verifiedJWT.BooleanClaim('cc-bool') err = %v, want nil", err)
	}
	if !cmp.Equal(boolean, opts.CustomClaims["cc-bool"]) {
		t.Errorf("verifiedJWT.BooleanClaim('cc-bool') = %v, want %v", boolean, opts.CustomClaims["cc-bool"])
	}
	str, err := verifiedJWT.StringClaim("cc-string")
	if err != nil {
		t.Errorf("verifiedJWT.StringClaim('cc-string') err = %v, want nil", err)
	}
	if !cmp.Equal(str, opts.CustomClaims["cc-string"]) {
		t.Errorf("verifiedJWT.StringClaim('cc-string') = %q, want %q", str, opts.CustomClaims["cc-string"])
	}
	// Arrays and objects are printed with %v: they may hold non-string JSON
	// values (the object below holds a float), which %q would render as
	// "%!q(float64=...)" and obscure the actual mismatch.
	array, err := verifiedJWT.ArrayClaim("cc-array")
	if err != nil {
		t.Errorf("verifiedJWT.ArrayClaim('cc-array') err = %v, want nil", err)
	}
	if !cmp.Equal(array, opts.CustomClaims["cc-array"]) {
		t.Errorf("verifiedJWT.ArrayClaim('cc-array') = %v, want %v", array, opts.CustomClaims["cc-array"])
	}
	object, err := verifiedJWT.ObjectClaim("cc-object")
	if err != nil {
		t.Errorf("verifiedJWT.ObjectClaim('cc-object') err = %v, want nil", err)
	}
	if !cmp.Equal(object, opts.CustomClaims["cc-object"]) {
		t.Errorf("verifiedJWT.ObjectClaim('cc-object') = %v, want %v", object, opts.CustomClaims["cc-object"])
	}
}
274
// TestCustomClaimIsFalseForWrongType checks that each Has*Claim method returns
// false when the named custom claim exists but holds a different JSON kind.
func TestCustomClaimIsFalseForWrongType(t *testing.T) {
	customClaims := map[string]interface{}{
		"cc-null":   nil,
		"cc-num":    1.67,
		"cc-bool":   true,
		"cc-string": "goo",
		"cc-array":  []interface{}{"1", "2", "3"},
		"cc-object": map[string]interface{}{"cc-nested-num": 5.99},
	}
	opts := &jwt.RawJWTOptions{
		WithoutExpiration: true,
		CustomClaims:      customClaims,
	}
	rawJWT, err := jwt.NewRawJWT(opts)
	if err != nil {
		t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err)
	}
	verifiedJWT, err := createVerifiedJWT(rawJWT)
	if err != nil {
		t.Fatalf("creating verifiedJWT: %v", err)
	}

	// Each entry pairs a Has*Claim query with a claim of a mismatched kind.
	mismatches := []struct {
		reported bool
		msg      string
	}{
		{verifiedJWT.HasNullClaim("cc-object"), "verifiedJWT.HasNullClaim('cc-object') = true, want false"},
		{verifiedJWT.HasNumberClaim("cc-bool"), "verifiedJWT.HasNumberClaim('cc-bool') = true, want false"},
		{verifiedJWT.HasStringClaim("cc-array"), "verifiedJWT.HasStringClaim('cc-array') = true, want false"},
		{verifiedJWT.HasBooleanClaim("cc-string"), "verifiedJWT.HasBooleanClaim('cc-string') = true, want false"},
		{verifiedJWT.HasArrayClaim("cc-null"), "verifiedJWT.HasArrayClaim('cc-null') = true, want false"},
		{verifiedJWT.HasObjectClaim("cc-num"), "verifiedJWT.HasObjectClaim('cc-num') = true, want false"},
	}
	for _, m := range mismatches {
		if m.reported {
			t.Error(m.msg)
		}
	}
}
314
// TestNoClaimsCallHasAndGet checks that a token carrying no header type, no
// registered claims and no custom claims reports every registered claim as
// absent via Has*, that the corresponding getters fail rather than return a
// zero value, and that no custom claim names are listed.
func TestNoClaimsCallHasAndGet(t *testing.T) {
	opts := &jwt.RawJWTOptions{
		WithoutExpiration: true,
	}
	rawJWT, err := jwt.NewRawJWT(opts)
	if err != nil {
		t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err)
	}
	verifiedJWT, err := createVerifiedJWT(rawJWT)
	if err != nil {
		t.Fatalf("creating verifiedJWT: %v", err)
	}
	if verifiedJWT.HasTypeHeader() {
		t.Errorf("verifiedJWT.HasTypeHeader() = true, want false")
	}
	if verifiedJWT.HasAudiences() {
		t.Errorf("verifiedJWT.HasAudiences() = true, want false")
	}
	if verifiedJWT.HasSubject() {
		t.Errorf("verifiedJWT.HasSubject() = true, want false")
	}
	if verifiedJWT.HasIssuer() {
		t.Errorf("verifiedJWT.HasIssuer() = true, want false")
	}
	if verifiedJWT.HasJWTID() {
		t.Errorf("verifiedJWT.HasJWTID() = true, want false")
	}
	if verifiedJWT.HasNotBefore() {
		t.Errorf("verifiedJWT.HasNotBefore() = true, want false")
	}
	if verifiedJWT.HasExpiration() {
		t.Errorf("verifiedJWT.HasExpiration() = true, want false")
	}
	if verifiedJWT.HasIssuedAt() {
		t.Errorf("verifiedJWT.HasIssuedAt() = true, want false")
	}

	// The "Get" half of this test: reading an absent claim must be an error,
	// not a silent zero value.
	if _, err := verifiedJWT.TypeHeader(); err == nil {
		t.Errorf("verifiedJWT.TypeHeader() err = nil, want error")
	}
	if _, err := verifiedJWT.Audiences(); err == nil {
		t.Errorf("verifiedJWT.Audiences() err = nil, want error")
	}
	if _, err := verifiedJWT.Subject(); err == nil {
		t.Errorf("verifiedJWT.Subject() err = nil, want error")
	}
	if _, err := verifiedJWT.Issuer(); err == nil {
		t.Errorf("verifiedJWT.Issuer() err = nil, want error")
	}
	if _, err := verifiedJWT.JWTID(); err == nil {
		t.Errorf("verifiedJWT.JWTID() err = nil, want error")
	}
	if _, err := verifiedJWT.NotBefore(); err == nil {
		t.Errorf("verifiedJWT.NotBefore() err = nil, want error")
	}
	if _, err := verifiedJWT.ExpiresAt(); err == nil {
		t.Errorf("verifiedJWT.ExpiresAt() err = nil, want error")
	}
	if _, err := verifiedJWT.IssuedAt(); err == nil {
		t.Errorf("verifiedJWT.IssuedAt() err = nil, want error")
	}

	// cmp.Equal distinguishes nil from an empty non-nil slice; "no custom
	// claims" should hold for either, so compare by length.
	if names := verifiedJWT.CustomClaimNames(); len(names) != 0 {
		t.Errorf("verifiedJWT.CustomClaimNames() = %q, want empty", names)
	}
}
352
// TestCantGetRegisteredClaimsThroughCustomClaims checks that the registered
// claim names (iss, sub, aud, jti, exp, nbf, iat) are not reachable through the
// string, number or array custom-claim accessors, even when the token actually
// carries all of those registered claims.
func TestCantGetRegisteredClaimsThroughCustomClaims(t *testing.T) {
	now := time.Now()
	opts := &jwt.RawJWTOptions{
		TypeHeader: refString("typeHeader"),
		Subject:    refString("test-subject"),
		Issuer:     refString("test-issuer"),
		JWTID:      refString("1"),
		Audiences:  []string{"foo", "bar"},
		ExpiresAt:  refTime(now.Add(24 * time.Hour).Unix()),
		IssuedAt:   refTime(now.Unix()),
		NotBefore:  refTime(now.Add(-2 * time.Hour).Unix()),
	}
	rawJWT, err := jwt.NewRawJWT(opts)
	if err != nil {
		t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err)
	}
	verifiedJWT, err := createVerifiedJWT(rawJWT)
	if err != nil {
		t.Fatalf("creating verifiedJWT: %v", err)
	}

	// Presence queries that must all report false for a registered name.
	hasChecks := []struct {
		format string
		has    func(string) bool
	}{
		{"verifiedJWT.HasStringClaim(%q) = true, want false", verifiedJWT.HasStringClaim},
		{"verifiedJWT.HasNumberClaim(%q) = true, want false", verifiedJWT.HasNumberClaim},
		{"verifiedJWT.HasArrayClaim(%q) = true, want false", verifiedJWT.HasArrayClaim},
	}
	// Getters that must all fail for a registered name; only the error matters.
	getChecks := []struct {
		format string
		get    func(string) error
	}{
		{"verifiedJWT.StringClaim(%q) err = nil, want error", func(name string) error {
			_, getErr := verifiedJWT.StringClaim(name)
			return getErr
		}},
		{"verifiedJWT.NumberClaim(%q) err = nil, want error", func(name string) error {
			_, getErr := verifiedJWT.NumberClaim(name)
			return getErr
		}},
		{"verifiedJWT.ArrayClaim(%q) err = nil, want error", func(name string) error {
			_, getErr := verifiedJWT.ArrayClaim(name)
			return getErr
		}},
	}

	for _, claim := range []string{"iss", "sub", "aud", "jti", "exp", "nbf", "iat"} {
		for _, hc := range hasChecks {
			if hc.has(claim) {
				t.Errorf(hc.format, claim)
			}
		}
		for _, gc := range getChecks {
			if gc.get(claim) == nil {
				t.Errorf(gc.format, claim)
			}
		}
	}
}
395
396func TestGetJSONPayload(t *testing.T) {
397	opts := &jwt.RawJWTOptions{
398		Subject:           refString("test-subject"),
399		WithoutExpiration: true,
400	}
401	rawJWT, err := jwt.NewRawJWT(opts)
402	if err != nil {
403		t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err)
404	}
405	verifiedJWT, err := createVerifiedJWT(rawJWT)
406	if err != nil {
407		t.Fatalf("creating verifiedJWT: %v", err)
408	}
409	j, err := verifiedJWT.JSONPayload()
410	if err != nil {
411		t.Errorf("verifiedJWT.JSONPayload() err = %v, want nil", err)
412	}
413	expected := `{"sub":"test-subject"}`
414	if !cmp.Equal(string(j), expected) {
415		t.Errorf("verifiedJWT.JSONPayload() = %q, want %q", string(j), expected)
416	}
417}
418