1// Copyright 2022 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package jwt_test 18 19import ( 20 "testing" 21 "time" 22 23 "github.com/google/go-cmp/cmp" 24 "github.com/google/go-cmp/cmp/cmpopts" 25 "github.com/google/tink/go/jwt" 26 "github.com/google/tink/go/keyset" 27) 28 29func createVerifiedJWT(rawJWT *jwt.RawJWT) (*jwt.VerifiedJWT, error) { 30 kh, err := keyset.NewHandle(jwt.HS256Template()) 31 if err != nil { 32 return nil, err 33 } 34 m, err := jwt.NewMAC(kh) 35 if err != nil { 36 return nil, err 37 } 38 compact, err := m.ComputeMACAndEncode(rawJWT) 39 if err != nil { 40 return nil, err 41 } 42 // This validator is purposely instantiated to always pass. 43 // It isn't really validating much and probably shouldn't 44 // be used like this out side of these tests. 45 opts := &jwt.ValidatorOpts{ 46 AllowMissingExpiration: true, 47 IgnoreTypeHeader: true, 48 IgnoreAudiences: true, 49 IgnoreIssuer: true, 50 } 51 issuedAt, err := rawJWT.IssuedAt() 52 if err == nil { 53 opts.FixedNow = issuedAt 54 } 55 56 validator, err := jwt.NewValidator(opts) 57 if err != nil { 58 return nil, err 59 } 60 return m.VerifyMACAndDecode(compact, validator) 61} 62 63func TestGetRegisteredStringClaims(t *testing.T) { 64 opts := &jwt.RawJWTOptions{ 65 TypeHeader: refString("typeHeader"), 66 Subject: refString("test-subject"), 67 Issuer: refString("test-issuer"), 68 JWTID: refString("1"), 69 WithoutExpiration: true, 70 } 71 rawJWT, err := jwt.NewRawJWT(opts) 72 if err != nil { 73 t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err) 74 } 75 verifiedJWT, err := createVerifiedJWT(rawJWT) 76 if err != nil { 77 t.Fatalf("creating verifiedJWT: %v", err) 78 } 79 if !verifiedJWT.HasTypeHeader() { 80 t.Errorf("verifiedJWT.HasTypeHeader() = false, want true") 81 } 82 if !verifiedJWT.HasSubject() { 83 t.Errorf("verifiedJWT.HasSubject() = false, want true") 84 } 85 if !verifiedJWT.HasIssuer() { 86 t.Errorf("verifiedJWT.HasIssuer() = false, want true") 87 } 88 if !verifiedJWT.HasJWTID() { 89 t.Errorf("verifiedJWT.HasJWTID() = false, want true") 90 } 91 typeHeader, err := verifiedJWT.TypeHeader() 92 if err != nil { 93 t.Errorf("verifiedJWT.TypeHeader() err = %v, want nil", err) 94 } 95 if !cmp.Equal(typeHeader, *opts.TypeHeader) { 96 t.Errorf("verifiedJWT.TypeHeader() = %q, want %q", typeHeader, *opts.TypeHeader) 97 } 98 subject, err := verifiedJWT.Subject() 99 if err != nil { 100 t.Errorf("verifiedJWT.Subject() err = %v, want nil", err) 101 } 102 if !cmp.Equal(subject, *opts.Subject) { 103 t.Errorf("verifiedJWT.Subject() = %q, want %q", subject, *opts.Subject) 104 } 105 issuer, err := verifiedJWT.Issuer() 106 if err != nil { 107 t.Errorf("verifiedJWT.Issuer() err = %v, want nil", err) 108 } 109 if !cmp.Equal(issuer, *opts.Issuer) { 110 t.Errorf("verifiedJWT.Issuer() = %q, want %q", issuer, *opts.Issuer) 111 } 112 jwtID, err := verifiedJWT.JWTID() 113 if err != nil { 114 t.Errorf("verifiedJWT.JWTID() err = %v, want nil", err) 115 } 116 if !cmp.Equal(jwtID, *opts.JWTID) { 117 t.Errorf("verifiedJWT.JWTID() = %q, want %q", jwtID, *opts.JWTID) 118 } 119 if !cmp.Equal(verifiedJWT.CustomClaimNames(), []string{}) { 120 t.Errorf("verifiedJWT.CustomClaimNames() = %q want %q", verifiedJWT.CustomClaimNames(), []string{}) 121 } 122} 123 124func TestGetRegisteredTimestampClaims(t *testing.T) { 125 now := time.Now() 126 opts := &jwt.RawJWTOptions{ 127 ExpiresAt: refTime(now.Add(time.Hour * 24).Unix()), 128 IssuedAt: refTime(now.Unix()), 129 NotBefore: refTime(now.Add(-time.Hour * 2).Unix()), 130 } 131 rawJWT, err := jwt.NewRawJWT(opts) 132 if err != nil { 133 t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err) 134 } 135 verifiedJWT, err := createVerifiedJWT(rawJWT) 136 if err != nil { 137 t.Fatalf("creating verifiedJWT: %v", err) 138 } 139 if !verifiedJWT.HasExpiration() { 140 t.Errorf("verifiedJWT.HasExpiration() = false, want true") 141 } 142 if !verifiedJWT.HasIssuedAt() { 143 t.Errorf("verifiedJWT.HasIssuedAt() = false, want true") 144 } 145 if !verifiedJWT.HasNotBefore() { 146 t.Errorf("verifiedJWT.HasNotBefore() = false, want true") 147 } 148 expiresAt, err := verifiedJWT.ExpiresAt() 149 if err != nil { 150 t.Errorf("verifiedJWT.ExpiresAt() err = %v, want nil", err) 151 } 152 if !cmp.Equal(expiresAt, *opts.ExpiresAt) { 153 t.Errorf("verifiedJWT.ExpiresAt() = %q, want %q", expiresAt, *opts.ExpiresAt) 154 } 155 issuedAt, err := verifiedJWT.IssuedAt() 156 if err != nil { 157 t.Errorf("verifiedJWT.IssuedAt() err = %v, want nil", err) 158 } 159 if !cmp.Equal(issuedAt, *opts.IssuedAt) { 160 t.Errorf("verifiedJWT.IssuedAt() = %q, want %q", issuedAt, *opts.IssuedAt) 161 } 162 notBefore, err := verifiedJWT.NotBefore() 163 if err != nil { 164 t.Errorf("verifiedJWT.NotBefore() err = %v, want nil", err) 165 } 166 if !cmp.Equal(notBefore, *opts.NotBefore) { 167 t.Errorf("verifiedJWT.NotBefore() = %q, want %q", notBefore, *opts.NotBefore) 168 } 169} 170 171func TestGetAudiencesClaim(t *testing.T) { 172 opts := &jwt.RawJWTOptions{ 173 WithoutExpiration: true, 174 Audiences: []string{"foo", "bar"}, 175 } 176 rawJWT, err := jwt.NewRawJWT(opts) 177 if err != nil { 178 t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err) 179 } 180 verifiedJWT, err := createVerifiedJWT(rawJWT) 181 if err != nil { 182 t.Fatalf("creating verifiedJWT: %v", err) 183 } 184 if !verifiedJWT.HasAudiences() { 185 t.Errorf("verifiedJWT.HasAudiences() = false, want true") 186 } 187 audiences, err := verifiedJWT.Audiences() 188 if err != nil { 189 t.Errorf("verifiedJWT.Audiences() err = %v, want nil", err) 190 } 191 if !cmp.Equal(audiences, opts.Audiences) { 192 t.Errorf("verifiedJWT.Audiences() = %q, want %q", audiences, opts.Audiences) 193 } 194} 195 196func TestGetCustomClaims(t *testing.T) { 197 opts := &jwt.RawJWTOptions{ 198 WithoutExpiration: true, 199 CustomClaims: map[string]interface{}{ 200 "cc-null": nil, 201 "cc-num": 1.67, 202 "cc-bool": true, 203 "cc-string": "goo", 204 "cc-array": []interface{}{"1", "2", "3"}, 205 "cc-object": map[string]interface{}{"cc-nested-num": 5.99}, 206 }, 207 } 208 rawJWT, err := jwt.NewRawJWT(opts) 209 if err != nil { 210 t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err) 211 } 212 verifiedJWT, err := createVerifiedJWT(rawJWT) 213 if err != nil { 214 t.Fatalf("creating verifiedJWT: %v", err) 215 } 216 wantCustomClaims := []string{"cc-num", "cc-bool", "cc-null", "cc-string", "cc-array", "cc-object"} 217 if !cmp.Equal(verifiedJWT.CustomClaimNames(), wantCustomClaims, cmpopts.SortSlices(func(a, b string) bool { return a < b })) { 218 t.Errorf("verifiedJWT.CustomClaimNames() = %q, want %q", verifiedJWT.CustomClaimNames(), wantCustomClaims) 219 } 220 if !verifiedJWT.HasNullClaim("cc-null") { 221 t.Errorf("verifiedJWT.HasNullClaim('cc-null') = false, want true") 222 } 223 if !verifiedJWT.HasNumberClaim("cc-num") { 224 t.Errorf("verifiedJWT.HasNumberClaim('cc-num') = false, want true") 225 } 226 if !verifiedJWT.HasBooleanClaim("cc-bool") { 227 t.Errorf("verifiedJWT.HasBooleanClaim('cc-bool') = false, want true") 228 } 229 if !verifiedJWT.HasStringClaim("cc-string") { 230 t.Errorf("verifiedJWT.HasStringClaim('cc-string') = false, want true") 231 } 232 if !verifiedJWT.HasArrayClaim("cc-array") { 233 t.Errorf("verifiedJWT.HasArrayClaim('cc-array') = false, want true") 234 } 235 if !verifiedJWT.HasObjectClaim("cc-object") { 236 t.Errorf("verifiedJWT.HasObjectClaim('cc-object') = false, want true") 237 } 238 number, err := verifiedJWT.NumberClaim("cc-num") 239 if err != nil { 240 t.Errorf("verifiedJWT.NumberClaim('cc-num') err = %v, want nil", err) 241 } 242 if !cmp.Equal(number, opts.CustomClaims["cc-num"]) { 243 t.Errorf("verifiedJWT.NumberClaim('cc-num') = %f, want %f", number, opts.CustomClaims["cc-num"]) 244 } 245 boolean, err := verifiedJWT.BooleanClaim("cc-bool") 246 if err != nil { 247 t.Errorf("verifiedJWT.BooleanClaim('cc-bool') err = %v, want nil", err) 248 } 249 if !cmp.Equal(boolean, opts.CustomClaims["cc-bool"]) { 250 t.Errorf("verifiedJWT.BooleanClaim('cc-bool') = %v, want %v", boolean, opts.CustomClaims["cc-bool"]) 251 } 252 str, err := verifiedJWT.StringClaim("cc-string") 253 if err != nil { 254 t.Errorf("verifiedJWT.StringClaim('cc-string') err = %v, want nil", err) 255 } 256 if !cmp.Equal(str, opts.CustomClaims["cc-string"]) { 257 t.Errorf("verifiedJWT.StringClaim('cc-string') = %q, want %q", str, opts.CustomClaims["cc-string"]) 258 } 259 array, err := verifiedJWT.ArrayClaim("cc-array") 260 if err != nil { 261 t.Errorf("verifiedJWT.ArrayClaim('cc-array') err = %v, want nil", err) 262 } 263 if !cmp.Equal(array, opts.CustomClaims["cc-array"]) { 264 t.Errorf("verifiedJWT.ArrayClaim('cc-array') = %q, want %q", array, opts.CustomClaims["cc-array"]) 265 } 266 object, err := verifiedJWT.ObjectClaim("cc-object") 267 if err != nil { 268 t.Errorf("verifiedJWT.ObjectClaim('cc-object') err = %v, want nil", err) 269 } 270 if !cmp.Equal(object, opts.CustomClaims["cc-object"]) { 271 t.Errorf("verifiedJWT.ObjectClaim('cc-object') = %q, want %q", object, opts.CustomClaims["cc-object"]) 272 } 273} 274 275func TestCustomClaimIsFalseForWrongType(t *testing.T) { 276 opts := &jwt.RawJWTOptions{ 277 WithoutExpiration: true, 278 CustomClaims: map[string]interface{}{ 279 "cc-null": nil, 280 "cc-num": 1.67, 281 "cc-bool": true, 282 "cc-string": "goo", 283 "cc-array": []interface{}{"1", "2", "3"}, 284 "cc-object": map[string]interface{}{"cc-nested-num": 5.99}, 285 }, 286 } 287 rawJWT, err := jwt.NewRawJWT(opts) 288 if err != nil { 289 t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err) 290 } 291 verifiedJWT, err := createVerifiedJWT(rawJWT) 292 if err != nil { 293 t.Fatalf("creating verifiedJWT: %v", err) 294 } 295 if verifiedJWT.HasNullClaim("cc-object") { 296 t.Errorf("verifiedJWT.HasNullClaim('cc-object') = true, want false") 297 } 298 if verifiedJWT.HasNumberClaim("cc-bool") { 299 t.Errorf("verifiedJWT.HasNumberClaim('cc-bool') = true, want false") 300 } 301 if verifiedJWT.HasStringClaim("cc-array") { 302 t.Errorf("verifiedJWT.HasStringClaim('cc-array') = true, want false") 303 } 304 if verifiedJWT.HasBooleanClaim("cc-string") { 305 t.Errorf("verifiedJWT.HasBooleanClaim('cc-string') = true, want false") 306 } 307 if verifiedJWT.HasArrayClaim("cc-null") { 308 t.Errorf("verifiedJWT.HasArrayClaim('cc-null') = true, want false") 309 } 310 if verifiedJWT.HasObjectClaim("cc-num") { 311 t.Errorf("verifiedJWT.HasObjectClaim('cc-num') = true, want false") 312 } 313} 314 315func TestNoClaimsCallHasAndGet(t *testing.T) { 316 opts := &jwt.RawJWTOptions{ 317 WithoutExpiration: true, 318 } 319 rawJWT, err := jwt.NewRawJWT(opts) 320 if err != nil { 321 t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err) 322 } 323 verifiedJWT, err := createVerifiedJWT(rawJWT) 324 if err != nil { 325 t.Fatalf("creating verifiedJWT: %v", err) 326 } 327 if verifiedJWT.HasAudiences() { 328 t.Errorf("verifiedJWT.HasAudiences() = true, want false") 329 } 330 if verifiedJWT.HasSubject() { 331 t.Errorf("verifiedJWT.HasSubject() = true, want false") 332 } 333 if verifiedJWT.HasIssuer() { 334 t.Errorf("verifiedJWT.HasIssuer() = true, want false") 335 } 336 if verifiedJWT.HasJWTID() { 337 t.Errorf("verifiedJWT.HasJWTID() = true, want false") 338 } 339 if verifiedJWT.HasNotBefore() { 340 t.Errorf("verifiedJWT.HasNotBefore() = true, want false") 341 } 342 if verifiedJWT.HasExpiration() { 343 t.Errorf("verifiedJWT.HasExpiration() = true, want false") 344 } 345 if verifiedJWT.HasIssuedAt() { 346 t.Errorf("verifiedJWT.HasIssuedAt() = true, want false") 347 } 348 if !cmp.Equal(verifiedJWT.CustomClaimNames(), []string{}) { 349 t.Errorf("verifiedJWT.CustomClaimNames() = %q want %q", verifiedJWT.CustomClaimNames(), []string{}) 350 } 351} 352 353func TestCantGetRegisteredClaimsThroughCustomClaims(t *testing.T) { 354 now := time.Now() 355 opts := &jwt.RawJWTOptions{ 356 TypeHeader: refString("typeHeader"), 357 Subject: refString("test-subject"), 358 Issuer: refString("test-issuer"), 359 JWTID: refString("1"), 360 Audiences: []string{"foo", "bar"}, 361 ExpiresAt: refTime(now.Add(time.Hour * 24).Unix()), 362 IssuedAt: refTime(now.Unix()), 363 NotBefore: refTime(now.Add(-time.Hour * 2).Unix()), 364 } 365 rawJWT, err := jwt.NewRawJWT(opts) 366 if err != nil { 367 t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err) 368 } 369 verifiedJWT, err := createVerifiedJWT(rawJWT) 370 if err != nil { 371 t.Fatalf("creating verifiedJWT: %v", err) 372 } 373 for _, c := range []string{"iss", "sub", "aud", "jti", "exp", "nbf", "iat"} { 374 if verifiedJWT.HasStringClaim(c) { 375 t.Errorf("verifiedJWT.HasStringClaim(%q) = true, want false", c) 376 } 377 if verifiedJWT.HasNumberClaim(c) { 378 t.Errorf("verifiedJWT.HasNumberClaim(%q) = true, want false", c) 379 } 380 if verifiedJWT.HasArrayClaim(c) { 381 t.Errorf("verifiedJWT.HasArrayClaim(%q) = true, want false", c) 382 } 383 384 if _, err := verifiedJWT.StringClaim(c); err == nil { 385 t.Errorf("verifiedJWT.StringClaim(%q) err = nil, want error", c) 386 } 387 if _, err := verifiedJWT.NumberClaim(c); err == nil { 388 t.Errorf("verifiedJWT.NumberClaim(%q) err = nil, want error", c) 389 } 390 if _, err := verifiedJWT.ArrayClaim(c); err == nil { 391 t.Errorf("verifiedJWT.ArrayClaim(%q) err = nil, want error", c) 392 } 393 } 394} 395 396func TestGetJSONPayload(t *testing.T) { 397 opts := &jwt.RawJWTOptions{ 398 Subject: refString("test-subject"), 399 WithoutExpiration: true, 400 } 401 rawJWT, err := jwt.NewRawJWT(opts) 402 if err != nil { 403 t.Fatalf("jwt.NewRawJWT(%v): %v", opts, err) 404 } 405 verifiedJWT, err := createVerifiedJWT(rawJWT) 406 if err != nil { 407 t.Fatalf("creating verifiedJWT: %v", err) 408 } 409 j, err := verifiedJWT.JSONPayload() 410 if err != nil { 411 t.Errorf("verifiedJWT.JSONPayload() err = %v, want nil", err) 412 } 413 expected := `{"sub":"test-subject"}` 414 if !cmp.Equal(string(j), expected) { 415 t.Errorf("verifiedJWT.JSONPayload() = %q, want %q", string(j), expected) 416 } 417} 418