// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////

package jwt

import (
	"fmt"
	"time"
	"unicode/utf8"

	spb "google.golang.org/protobuf/types/known/structpb"
)

// Names of the registered claims handled by this file. They cannot be used as
// custom claim names.
const (
	claimIssuer     = "iss"
	claimSubject    = "sub"
	claimAudience   = "aud"
	claimExpiration = "exp"
	claimNotBefore  = "nbf"
	claimIssuedAt   = "iat"
	claimJWTID      = "jti"

	// jwtTimestampMax and jwtTimestampMin are the inclusive bounds, in seconds
	// since the Unix epoch, accepted for the 'exp', 'nbf' and 'iat' claims.
	// 253402300799 is 9999-12-31T23:59:59Z.
	jwtTimestampMax = 253402300799
	jwtTimestampMin = 0
)

// RawJWTOptions holds the claims and headers that NewRawJWT uses to build an
// unsigned JSON Web Token (JWT), https://tools.ietf.org/html/rfc7519.
//
// It contains all payload claims and a subset of the headers. It does not
// contain any headers that depend on the key, such as "alg" or "kid", because
// these headers are chosen when the token is signed and encoded, and should not
// be chosen by the user. This ensures that the key can be changed without any
// changes to the user code.
//
// Nil pointer fields are omitted from the token.
type RawJWTOptions struct {
	// Audiences is stored as a JSON array in the 'aud' claim. It must not be
	// set together with Audience, and a non-nil slice must have at least one
	// entry.
	Audiences []string
	// Audience is stored as a single JSON string in the 'aud' claim.
	Audience *string
	// Subject is stored in the 'sub' claim.
	Subject *string
	// Issuer is stored in the 'iss' claim.
	Issuer *string
	// JWTID is stored in the 'jti' claim.
	JWTID *string
	// IssuedAt is stored in the 'iat' claim as whole seconds since the Unix
	// epoch. Sub-second precision is dropped.
	IssuedAt *time.Time
	// ExpiresAt is stored in the 'exp' claim as whole seconds since the Unix
	// epoch. It is required unless WithoutExpiration is true.
	ExpiresAt *time.Time
	// NotBefore is stored in the 'nbf' claim as whole seconds since the Unix
	// epoch.
	NotBefore *time.Time
	// CustomClaims holds additional claims. Keys must not be registered claim
	// names, and values must be accepted by structpb.NewValue.
	CustomClaims map[string]interface{}

	// TypeHeader is the optional type header of the token. Nil means the token
	// has no type header.
	TypeHeader *string
	// WithoutExpiration must be true exactly when ExpiresAt is nil. It makes a
	// token without an expiration an explicit choice.
	WithoutExpiration bool
}

// RawJWT is an unsigned JSON Web Token (JWT), https://tools.ietf.org/html/rfc7519.
63type RawJWT struct { 64 jsonpb *spb.Struct 65 typeHeader *string 66} 67 68// NewRawJWT constructs a new RawJWT token based on the RawJwtOptions provided. 69func NewRawJWT(opts *RawJWTOptions) (*RawJWT, error) { 70 if opts == nil { 71 return nil, fmt.Errorf("jwt options can't be nil") 72 } 73 payload, err := createPayload(opts) 74 if err != nil { 75 return nil, err 76 } 77 if err := validatePayload(payload); err != nil { 78 return nil, err 79 } 80 return &RawJWT{ 81 jsonpb: payload, 82 typeHeader: opts.TypeHeader, 83 }, nil 84} 85 86// NewRawJWTFromJSON builds a RawJWT from a marshaled JSON. 87// Users shouldn't call this function and instead use NewRawJWT. 88func NewRawJWTFromJSON(typeHeader *string, jsonPayload []byte) (*RawJWT, error) { 89 payload := &spb.Struct{} 90 if err := payload.UnmarshalJSON(jsonPayload); err != nil { 91 return nil, err 92 } 93 if err := validatePayload(payload); err != nil { 94 return nil, err 95 } 96 return &RawJWT{ 97 jsonpb: payload, 98 typeHeader: typeHeader, 99 }, nil 100} 101 102// JSONPayload marshals a RawJWT payload to JSON. 103func (r *RawJWT) JSONPayload() ([]byte, error) { 104 return r.jsonpb.MarshalJSON() 105} 106 107// HasTypeHeader returns whether a RawJWT contains a type header. 108func (r *RawJWT) HasTypeHeader() bool { 109 return r.typeHeader != nil 110} 111 112// TypeHeader returns the JWT type header. 113func (r *RawJWT) TypeHeader() (string, error) { 114 if !r.HasTypeHeader() { 115 return "", fmt.Errorf("no type header present") 116 } 117 return *r.typeHeader, nil 118} 119 120// HasAudiences checks whether a JWT contains the audience claim ('aud'). 121func (r *RawJWT) HasAudiences() bool { 122 return r.hasField(claimAudience) 123} 124 125// Audiences returns a list of audiences from the 'aud' claim. If the 'aud' claim is a single string, it is converted into a list with a single entry. 
126func (r *RawJWT) Audiences() ([]string, error) { 127 aud, ok := r.field(claimAudience) 128 if !ok { 129 return nil, fmt.Errorf("no audience claim found") 130 } 131 if err := validateAudienceClaim(aud); err != nil { 132 return nil, err 133 } 134 if val, isString := aud.GetKind().(*spb.Value_StringValue); isString { 135 return []string{val.StringValue}, nil 136 } 137 s := make([]string, 0, len(aud.GetListValue().GetValues())) 138 for _, a := range aud.GetListValue().GetValues() { 139 s = append(s, a.GetStringValue()) 140 } 141 return s, nil 142} 143 144// HasSubject checks whether a JWT contains an issuer claim ('sub'). 145func (r *RawJWT) HasSubject() bool { 146 return r.hasField(claimSubject) 147} 148 149// Subject returns the subject claim ('sub') or an error if no claim is present. 150func (r *RawJWT) Subject() (string, error) { 151 return r.stringClaim(claimSubject) 152} 153 154// HasIssuer checks whether a JWT contains an issuer claim ('iss'). 155func (r *RawJWT) HasIssuer() bool { 156 return r.hasField(claimIssuer) 157} 158 159// Issuer returns the issuer claim ('iss') or an error if no claim is present. 160func (r *RawJWT) Issuer() (string, error) { 161 return r.stringClaim(claimIssuer) 162} 163 164// HasJWTID checks whether a JWT contains an JWT ID claim ('jti'). 165func (r *RawJWT) HasJWTID() bool { 166 return r.hasField(claimJWTID) 167} 168 169// JWTID returns the JWT ID claim ('jti') or an error if no claim is present. 170func (r *RawJWT) JWTID() (string, error) { 171 return r.stringClaim(claimJWTID) 172} 173 174// HasIssuedAt checks whether a JWT contains an issued at claim ('iat'). 175func (r *RawJWT) HasIssuedAt() bool { 176 return r.hasField(claimIssuedAt) 177} 178 179// IssuedAt returns the issued at claim ('iat') or an error if no claim is present. 180func (r *RawJWT) IssuedAt() (time.Time, error) { 181 return r.timeClaim(claimIssuedAt) 182} 183 184// HasExpiration checks whether a JWT contains an expiration time claim ('exp'). 
185func (r *RawJWT) HasExpiration() bool { 186 return r.hasField(claimExpiration) 187} 188 189// ExpiresAt returns the expiration claim ('exp') or an error if no claim is present. 190func (r *RawJWT) ExpiresAt() (time.Time, error) { 191 return r.timeClaim(claimExpiration) 192} 193 194// HasNotBefore checks whether a JWT contains a not before claim ('nbf'). 195func (r *RawJWT) HasNotBefore() bool { 196 return r.hasField(claimNotBefore) 197} 198 199// NotBefore returns the not before claim ('nbf') or an error if no claim is present. 200func (r *RawJWT) NotBefore() (time.Time, error) { 201 return r.timeClaim(claimNotBefore) 202} 203 204// HasStringClaim checks whether a claim of type string is present. 205func (r *RawJWT) HasStringClaim(name string) bool { 206 return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_StringValue{}}) 207} 208 209// StringClaim returns a custom string claim or an error if no claim is present. 210func (r *RawJWT) StringClaim(name string) (string, error) { 211 if isRegisteredClaim(name) { 212 return "", fmt.Errorf("claim '%q' is a registered claim", name) 213 } 214 return r.stringClaim(name) 215} 216 217// HasNumberClaim checks whether a claim of type number is present. 218func (r *RawJWT) HasNumberClaim(name string) bool { 219 return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_NumberValue{}}) 220} 221 222// NumberClaim returns a custom number claim or an error if no claim is present. 223func (r *RawJWT) NumberClaim(name string) (float64, error) { 224 if isRegisteredClaim(name) { 225 return 0, fmt.Errorf("claim '%q' is a registered claim", name) 226 } 227 return r.numberClaim(name) 228} 229 230// HasBooleanClaim checks whether a claim of type boolean is present. 
231func (r *RawJWT) HasBooleanClaim(name string) bool { 232 return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_BoolValue{}}) 233} 234 235// BooleanClaim returns a custom bool claim or an error if no claim is present. 236func (r *RawJWT) BooleanClaim(name string) (bool, error) { 237 val, err := r.customClaim(name) 238 if err != nil { 239 return false, err 240 } 241 b, ok := val.Kind.(*spb.Value_BoolValue) 242 if !ok { 243 return false, fmt.Errorf("claim '%q' is not a boolean", name) 244 } 245 return b.BoolValue, nil 246} 247 248// HasNullClaim checks whether a claim of type null is present. 249func (r *RawJWT) HasNullClaim(name string) bool { 250 return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_NullValue{}}) 251} 252 253// HasArrayClaim checks whether a claim of type list is present. 254func (r *RawJWT) HasArrayClaim(name string) bool { 255 return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_ListValue{}}) 256} 257 258// ArrayClaim returns a slice representing a JSON array for a claim or an error if the claim is empty. 259func (r *RawJWT) ArrayClaim(name string) ([]interface{}, error) { 260 val, err := r.customClaim(name) 261 if err != nil { 262 return nil, err 263 } 264 if val.GetListValue() == nil { 265 return nil, fmt.Errorf("claim '%q' is not a list", name) 266 } 267 return val.GetListValue().AsSlice(), nil 268} 269 270// HasObjectClaim checks whether a claim of type JSON object is present. 271func (r *RawJWT) HasObjectClaim(name string) bool { 272 return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_StructValue{}}) 273} 274 275// ObjectClaim returns a map representing a JSON object for a claim or an error if the claim is empty. 
276func (r *RawJWT) ObjectClaim(name string) (map[string]interface{}, error) { 277 val, err := r.customClaim(name) 278 if err != nil { 279 return nil, err 280 } 281 if val.GetStructValue() == nil { 282 return nil, fmt.Errorf("claim '%q' is not a JSON object", name) 283 } 284 return val.GetStructValue().AsMap(), err 285} 286 287// CustomClaimNames returns a list with the name of custom claims in a RawJWT. 288func (r *RawJWT) CustomClaimNames() []string { 289 names := []string{} 290 for key := range r.jsonpb.GetFields() { 291 if !isRegisteredClaim(key) { 292 names = append(names, key) 293 } 294 } 295 return names 296} 297 298func (r *RawJWT) timeClaim(name string) (time.Time, error) { 299 n, err := r.numberClaim(name) 300 if err != nil { 301 return time.Time{}, err 302 } 303 return time.Unix(int64(n), 0), err 304} 305 306func (r *RawJWT) numberClaim(name string) (float64, error) { 307 val, ok := r.field(name) 308 if !ok { 309 return 0, fmt.Errorf("no '%q' claim found", name) 310 } 311 s, ok := val.Kind.(*spb.Value_NumberValue) 312 if !ok { 313 return 0, fmt.Errorf("claim '%q' is not a number", name) 314 } 315 return s.NumberValue, nil 316} 317 318func (r *RawJWT) stringClaim(name string) (string, error) { 319 val, ok := r.field(name) 320 if !ok { 321 return "", fmt.Errorf("no '%q' claim found", name) 322 } 323 s, ok := val.Kind.(*spb.Value_StringValue) 324 if !ok { 325 return "", fmt.Errorf("claim '%q' is not a string", name) 326 } 327 if !utf8.ValidString(s.StringValue) { 328 return "", fmt.Errorf("claim '%q' is not a valid utf-8 encoded string", name) 329 } 330 return s.StringValue, nil 331} 332 333func (r *RawJWT) hasClaimOfKind(name string, exp *spb.Value) bool { 334 val, exist := r.field(name) 335 if !exist || exp == nil { 336 return false 337 } 338 var isKind bool 339 switch exp.GetKind().(type) { 340 case *spb.Value_StructValue: 341 _, isKind = val.GetKind().(*spb.Value_StructValue) 342 case *spb.Value_NullValue: 343 _, isKind = 
val.GetKind().(*spb.Value_NullValue) 344 case *spb.Value_BoolValue: 345 _, isKind = val.GetKind().(*spb.Value_BoolValue) 346 case *spb.Value_ListValue: 347 _, isKind = val.GetKind().(*spb.Value_ListValue) 348 case *spb.Value_StringValue: 349 _, isKind = val.GetKind().(*spb.Value_StringValue) 350 case *spb.Value_NumberValue: 351 _, isKind = val.GetKind().(*spb.Value_NumberValue) 352 default: 353 isKind = false 354 } 355 return isKind 356} 357 358func (r *RawJWT) customClaim(name string) (*spb.Value, error) { 359 if isRegisteredClaim(name) { 360 return nil, fmt.Errorf("'%q' is a registered claim", name) 361 } 362 val, ok := r.field(name) 363 if !ok { 364 return nil, fmt.Errorf("claim '%q' not found", name) 365 } 366 return val, nil 367} 368 369func (r *RawJWT) hasField(name string) bool { 370 _, ok := r.field(name) 371 return ok 372} 373 374func (r *RawJWT) field(name string) (*spb.Value, bool) { 375 val, ok := r.jsonpb.GetFields()[name] 376 return val, ok 377} 378 379// createPayload creates a JSON payload from JWT options. 
380func createPayload(opts *RawJWTOptions) (*spb.Struct, error) { 381 if err := validateCustomClaims(opts.CustomClaims); err != nil { 382 return nil, err 383 } 384 if opts.ExpiresAt == nil && !opts.WithoutExpiration { 385 return nil, fmt.Errorf("jwt options must contain an expiration or must be marked WithoutExpiration") 386 } 387 if opts.ExpiresAt != nil && opts.WithoutExpiration { 388 return nil, fmt.Errorf("jwt options can't be marked WithoutExpiration when expiration is specified") 389 } 390 if opts.Audience != nil && opts.Audiences != nil { 391 return nil, fmt.Errorf("jwt options can either contain a single Audience or a list of Audiences but not both") 392 } 393 394 payload := &spb.Struct{ 395 Fields: map[string]*spb.Value{}, 396 } 397 setStringValue(payload, claimJWTID, opts.JWTID) 398 setStringValue(payload, claimIssuer, opts.Issuer) 399 setStringValue(payload, claimSubject, opts.Subject) 400 setStringValue(payload, claimAudience, opts.Audience) 401 setTimeValue(payload, claimIssuedAt, opts.IssuedAt) 402 setTimeValue(payload, claimNotBefore, opts.NotBefore) 403 setTimeValue(payload, claimExpiration, opts.ExpiresAt) 404 setAudiences(payload, claimAudience, opts.Audiences) 405 406 for k, v := range opts.CustomClaims { 407 val, err := spb.NewValue(v) 408 if err != nil { 409 return nil, err 410 } 411 setValue(payload, k, val) 412 } 413 return payload, nil 414} 415 416func validatePayload(payload *spb.Struct) error { 417 if payload.Fields == nil || len(payload.Fields) == 0 { 418 return nil 419 } 420 if err := validateAudienceClaim(payload.Fields[claimAudience]); err != nil { 421 return err 422 } 423 for claim, val := range payload.GetFields() { 424 if isRegisteredTimeClaim(claim) { 425 if err := validateTimeClaim(claim, val); err != nil { 426 return err 427 } 428 } 429 430 if isRegisteredStringClaim(claim) { 431 if err := validateStringClaim(claim, val); err != nil { 432 return err 433 } 434 } 435 } 436 return nil 437} 438 439func validateStringClaim(claim 
string, val *spb.Value) error { 440 v, ok := val.Kind.(*spb.Value_StringValue) 441 if !ok { 442 return fmt.Errorf("claim: '%q' MUST be a string", claim) 443 } 444 if !utf8.ValidString(v.StringValue) { 445 return fmt.Errorf("claim: '%q' isn't a valid UTF-8 string", claim) 446 } 447 return nil 448} 449 450func validateTimeClaim(claim string, val *spb.Value) error { 451 if _, ok := val.Kind.(*spb.Value_NumberValue); !ok { 452 return fmt.Errorf("claim %q MUST be a numeric value, ", claim) 453 } 454 t := int64(val.GetNumberValue()) 455 if t > jwtTimestampMax || t < jwtTimestampMin { 456 return fmt.Errorf("invalid timestamp: '%d' for claim: %q", t, claim) 457 } 458 return nil 459} 460 461func validateAudienceClaim(val *spb.Value) error { 462 if val == nil { 463 return nil 464 } 465 _, isString := val.Kind.(*spb.Value_StringValue) 466 l, isList := val.Kind.(*spb.Value_ListValue) 467 if !isList && !isString { 468 return fmt.Errorf("audience claim MUST be a list with at least one string or a single string value") 469 } 470 if isString { 471 return validateStringClaim(claimAudience, val) 472 } 473 if l.ListValue != nil && len(l.ListValue.Values) == 0 { 474 return fmt.Errorf("there MUST be at least one value present in the audience claim") 475 } 476 for _, aud := range l.ListValue.Values { 477 v, ok := aud.Kind.(*spb.Value_StringValue) 478 if !ok { 479 return fmt.Errorf("audience value is not a string") 480 } 481 if !utf8.ValidString(v.StringValue) { 482 return fmt.Errorf("audience value is not a valid UTF-8 string") 483 } 484 } 485 return nil 486} 487 488func validateCustomClaims(cc map[string]interface{}) error { 489 if cc == nil { 490 return nil 491 } 492 for key := range cc { 493 if isRegisteredClaim(key) { 494 return fmt.Errorf("claim '%q' is a registered claim, it can't be declared as a custom claim", key) 495 } 496 } 497 return nil 498} 499 500func setTimeValue(p *spb.Struct, claim string, val *time.Time) { 501 if val == nil { 502 return 503 } 504 setValue(p, claim, 
spb.NewNumberValue(float64(val.Unix()))) 505} 506 507func setStringValue(p *spb.Struct, claim string, val *string) { 508 if val == nil { 509 return 510 } 511 setValue(p, claim, spb.NewStringValue(*val)) 512} 513 514func setAudiences(p *spb.Struct, claim string, vals []string) { 515 if vals == nil { 516 return 517 } 518 audList := &spb.ListValue{ 519 Values: make([]*spb.Value, 0, len(vals)), 520 } 521 for _, aud := range vals { 522 audList.Values = append(audList.Values, spb.NewStringValue(aud)) 523 } 524 setValue(p, claim, spb.NewListValue(audList)) 525} 526 527func setValue(p *spb.Struct, claim string, val *spb.Value) { 528 if p.GetFields() == nil { 529 p.Fields = make(map[string]*spb.Value) 530 } 531 p.GetFields()[claim] = val 532} 533 534func isRegisteredClaim(c string) bool { 535 return isRegisteredStringClaim(c) || isRegisteredTimeClaim(c) || c == claimAudience 536} 537 538func isRegisteredStringClaim(c string) bool { 539 return c == claimIssuer || c == claimSubject || c == claimJWTID 540} 541 542func isRegisteredTimeClaim(c string) bool { 543 return c == claimExpiration || c == claimNotBefore || c == claimIssuedAt 544} 545