xref: /aosp_15_r20/external/tink/go/jwt/raw_jwt.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2021 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package jwt
18
19import (
20	"fmt"
21	"time"
22	"unicode/utf8"
23
24	spb "google.golang.org/protobuf/types/known/structpb"
25)
26
const (
	// JSON member names of the registered claims handled by this file.
	claimIssuer     = "iss"
	claimSubject    = "sub"
	claimAudience   = "aud"
	claimExpiration = "exp"
	claimNotBefore  = "nbf"
	claimIssuedAt   = "iat"
	claimJWTID      = "jti"

	// Inclusive bounds, in seconds since the Unix epoch, accepted by
	// validateTimeClaim for the "exp", "nbf" and "iat" claims.
	// NOTE(review): the maximum presumably corresponds to
	// 9999-12-31T23:59:59Z — confirm.
	jwtTimestampMax = 253402300799
	jwtTimestampMin = 0
)
39
// RawJWTOptions represents an unsigned JSON Web Token (JWT), https://tools.ietf.org/html/rfc7519.
//
// It contains all payload claims and a subset of the headers. It does not
// contain any headers that depend on the key, such as "alg" or "kid", because
// these headers are chosen when the token is signed and encoded, and should not
// be chosen by the user. This ensures that the key can be changed without any
// changes to the user code.
type RawJWTOptions struct {
	// Audiences is written to the "aud" claim as a list of strings.
	// It must not be set together with Audience.
	Audiences []string
	// Audience is written to the "aud" claim as a single string.
	// It must not be set together with Audiences.
	Audience *string
	// Subject is written to the "sub" claim when non-nil.
	Subject *string
	// Issuer is written to the "iss" claim when non-nil.
	Issuer *string
	// JWTID is written to the "jti" claim when non-nil.
	JWTID *string
	// IssuedAt is written to the "iat" claim, as Unix seconds, when non-nil.
	IssuedAt *time.Time
	// ExpiresAt is written to the "exp" claim, as Unix seconds. It must be
	// set unless WithoutExpiration is true, and must be nil if it is true.
	ExpiresAt *time.Time
	// NotBefore is written to the "nbf" claim, as Unix seconds, when non-nil.
	NotBefore *time.Time
	// CustomClaims holds additional claims keyed by claim name. Keys that
	// are registered claim names are rejected.
	CustomClaims map[string]interface{}

	// TypeHeader is stored as the type header of the resulting RawJWT.
	TypeHeader *string
	// WithoutExpiration must be set to true to build a token that has no
	// expiration.
	WithoutExpiration bool
}
61
// RawJWT is an unsigned JSON Web Token (JWT), https://tools.ietf.org/html/rfc7519.
type RawJWT struct {
	// jsonpb holds the payload claims as a protobuf Struct.
	jsonpb *spb.Struct
	// typeHeader is the optional type header; nil means it is absent.
	typeHeader *string
}
67
68// NewRawJWT constructs a new RawJWT token based on the RawJwtOptions provided.
69func NewRawJWT(opts *RawJWTOptions) (*RawJWT, error) {
70	if opts == nil {
71		return nil, fmt.Errorf("jwt options can't be nil")
72	}
73	payload, err := createPayload(opts)
74	if err != nil {
75		return nil, err
76	}
77	if err := validatePayload(payload); err != nil {
78		return nil, err
79	}
80	return &RawJWT{
81		jsonpb:     payload,
82		typeHeader: opts.TypeHeader,
83	}, nil
84}
85
86// NewRawJWTFromJSON builds a RawJWT from a marshaled JSON.
87// Users shouldn't call this function and instead use NewRawJWT.
88func NewRawJWTFromJSON(typeHeader *string, jsonPayload []byte) (*RawJWT, error) {
89	payload := &spb.Struct{}
90	if err := payload.UnmarshalJSON(jsonPayload); err != nil {
91		return nil, err
92	}
93	if err := validatePayload(payload); err != nil {
94		return nil, err
95	}
96	return &RawJWT{
97		jsonpb:     payload,
98		typeHeader: typeHeader,
99	}, nil
100}
101
102// JSONPayload marshals a RawJWT payload to JSON.
103func (r *RawJWT) JSONPayload() ([]byte, error) {
104	return r.jsonpb.MarshalJSON()
105}
106
107// HasTypeHeader returns whether a RawJWT contains a type header.
108func (r *RawJWT) HasTypeHeader() bool {
109	return r.typeHeader != nil
110}
111
112// TypeHeader returns the JWT type header.
113func (r *RawJWT) TypeHeader() (string, error) {
114	if !r.HasTypeHeader() {
115		return "", fmt.Errorf("no type header present")
116	}
117	return *r.typeHeader, nil
118}
119
120// HasAudiences checks whether a JWT contains the audience claim ('aud').
121func (r *RawJWT) HasAudiences() bool {
122	return r.hasField(claimAudience)
123}
124
125// Audiences returns a list of audiences from the 'aud' claim. If the 'aud' claim is a single string, it is converted into a list with a single entry.
126func (r *RawJWT) Audiences() ([]string, error) {
127	aud, ok := r.field(claimAudience)
128	if !ok {
129		return nil, fmt.Errorf("no audience claim found")
130	}
131	if err := validateAudienceClaim(aud); err != nil {
132		return nil, err
133	}
134	if val, isString := aud.GetKind().(*spb.Value_StringValue); isString {
135		return []string{val.StringValue}, nil
136	}
137	s := make([]string, 0, len(aud.GetListValue().GetValues()))
138	for _, a := range aud.GetListValue().GetValues() {
139		s = append(s, a.GetStringValue())
140	}
141	return s, nil
142}
143
144// HasSubject checks whether a JWT contains an issuer claim ('sub').
145func (r *RawJWT) HasSubject() bool {
146	return r.hasField(claimSubject)
147}
148
149// Subject returns the subject claim ('sub') or an error if no claim is present.
150func (r *RawJWT) Subject() (string, error) {
151	return r.stringClaim(claimSubject)
152}
153
154// HasIssuer checks whether a JWT contains an issuer claim ('iss').
155func (r *RawJWT) HasIssuer() bool {
156	return r.hasField(claimIssuer)
157}
158
159// Issuer returns the issuer claim ('iss') or an error if no claim is present.
160func (r *RawJWT) Issuer() (string, error) {
161	return r.stringClaim(claimIssuer)
162}
163
164// HasJWTID checks whether a JWT contains an JWT ID claim ('jti').
165func (r *RawJWT) HasJWTID() bool {
166	return r.hasField(claimJWTID)
167}
168
169// JWTID returns the JWT ID claim ('jti') or an error if no claim is present.
170func (r *RawJWT) JWTID() (string, error) {
171	return r.stringClaim(claimJWTID)
172}
173
174// HasIssuedAt checks whether a JWT contains an issued at claim ('iat').
175func (r *RawJWT) HasIssuedAt() bool {
176	return r.hasField(claimIssuedAt)
177}
178
179// IssuedAt returns the issued at claim ('iat') or an error if no claim is present.
180func (r *RawJWT) IssuedAt() (time.Time, error) {
181	return r.timeClaim(claimIssuedAt)
182}
183
184// HasExpiration checks whether a JWT contains an expiration time claim ('exp').
185func (r *RawJWT) HasExpiration() bool {
186	return r.hasField(claimExpiration)
187}
188
189// ExpiresAt returns the expiration claim ('exp') or an error if no claim is present.
190func (r *RawJWT) ExpiresAt() (time.Time, error) {
191	return r.timeClaim(claimExpiration)
192}
193
194// HasNotBefore checks whether a JWT contains a not before claim ('nbf').
195func (r *RawJWT) HasNotBefore() bool {
196	return r.hasField(claimNotBefore)
197}
198
199// NotBefore returns the not before claim ('nbf') or an error if no claim is present.
200func (r *RawJWT) NotBefore() (time.Time, error) {
201	return r.timeClaim(claimNotBefore)
202}
203
204// HasStringClaim checks whether a claim of type string is present.
205func (r *RawJWT) HasStringClaim(name string) bool {
206	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_StringValue{}})
207}
208
209// StringClaim returns a custom string claim or an error if no claim is present.
210func (r *RawJWT) StringClaim(name string) (string, error) {
211	if isRegisteredClaim(name) {
212		return "", fmt.Errorf("claim '%q' is a registered claim", name)
213	}
214	return r.stringClaim(name)
215}
216
217// HasNumberClaim checks whether a claim of type number is present.
218func (r *RawJWT) HasNumberClaim(name string) bool {
219	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_NumberValue{}})
220}
221
222// NumberClaim returns a custom number claim or an error if no claim is present.
223func (r *RawJWT) NumberClaim(name string) (float64, error) {
224	if isRegisteredClaim(name) {
225		return 0, fmt.Errorf("claim '%q' is a registered claim", name)
226	}
227	return r.numberClaim(name)
228}
229
230// HasBooleanClaim checks whether a claim of type boolean is present.
231func (r *RawJWT) HasBooleanClaim(name string) bool {
232	return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_BoolValue{}})
233}
234
235// BooleanClaim returns a custom bool claim or an error if no claim is present.
236func (r *RawJWT) BooleanClaim(name string) (bool, error) {
237	val, err := r.customClaim(name)
238	if err != nil {
239		return false, err
240	}
241	b, ok := val.Kind.(*spb.Value_BoolValue)
242	if !ok {
243		return false, fmt.Errorf("claim '%q' is not a boolean", name)
244	}
245	return b.BoolValue, nil
246}
247
248// HasNullClaim checks whether a claim of type null is present.
249func (r *RawJWT) HasNullClaim(name string) bool {
250	return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_NullValue{}})
251}
252
253// HasArrayClaim checks whether a claim of type list is present.
254func (r *RawJWT) HasArrayClaim(name string) bool {
255	return !isRegisteredClaim(name) && r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_ListValue{}})
256}
257
258// ArrayClaim returns a slice representing a JSON array for a claim or an error if the claim is empty.
259func (r *RawJWT) ArrayClaim(name string) ([]interface{}, error) {
260	val, err := r.customClaim(name)
261	if err != nil {
262		return nil, err
263	}
264	if val.GetListValue() == nil {
265		return nil, fmt.Errorf("claim '%q' is not a list", name)
266	}
267	return val.GetListValue().AsSlice(), nil
268}
269
270// HasObjectClaim checks whether a claim of type JSON object is present.
271func (r *RawJWT) HasObjectClaim(name string) bool {
272	return r.hasClaimOfKind(name, &spb.Value{Kind: &spb.Value_StructValue{}})
273}
274
275// ObjectClaim returns a map representing a JSON object for a claim or an error if the claim is empty.
276func (r *RawJWT) ObjectClaim(name string) (map[string]interface{}, error) {
277	val, err := r.customClaim(name)
278	if err != nil {
279		return nil, err
280	}
281	if val.GetStructValue() == nil {
282		return nil, fmt.Errorf("claim '%q' is not a JSON object", name)
283	}
284	return val.GetStructValue().AsMap(), err
285}
286
287// CustomClaimNames returns a list with the name of custom claims in a RawJWT.
288func (r *RawJWT) CustomClaimNames() []string {
289	names := []string{}
290	for key := range r.jsonpb.GetFields() {
291		if !isRegisteredClaim(key) {
292			names = append(names, key)
293		}
294	}
295	return names
296}
297
298func (r *RawJWT) timeClaim(name string) (time.Time, error) {
299	n, err := r.numberClaim(name)
300	if err != nil {
301		return time.Time{}, err
302	}
303	return time.Unix(int64(n), 0), err
304}
305
306func (r *RawJWT) numberClaim(name string) (float64, error) {
307	val, ok := r.field(name)
308	if !ok {
309		return 0, fmt.Errorf("no '%q' claim found", name)
310	}
311	s, ok := val.Kind.(*spb.Value_NumberValue)
312	if !ok {
313		return 0, fmt.Errorf("claim '%q' is not a number", name)
314	}
315	return s.NumberValue, nil
316}
317
318func (r *RawJWT) stringClaim(name string) (string, error) {
319	val, ok := r.field(name)
320	if !ok {
321		return "", fmt.Errorf("no '%q' claim found", name)
322	}
323	s, ok := val.Kind.(*spb.Value_StringValue)
324	if !ok {
325		return "", fmt.Errorf("claim '%q' is not a string", name)
326	}
327	if !utf8.ValidString(s.StringValue) {
328		return "", fmt.Errorf("claim '%q' is not a valid utf-8 encoded string", name)
329	}
330	return s.StringValue, nil
331}
332
333func (r *RawJWT) hasClaimOfKind(name string, exp *spb.Value) bool {
334	val, exist := r.field(name)
335	if !exist || exp == nil {
336		return false
337	}
338	var isKind bool
339	switch exp.GetKind().(type) {
340	case *spb.Value_StructValue:
341		_, isKind = val.GetKind().(*spb.Value_StructValue)
342	case *spb.Value_NullValue:
343		_, isKind = val.GetKind().(*spb.Value_NullValue)
344	case *spb.Value_BoolValue:
345		_, isKind = val.GetKind().(*spb.Value_BoolValue)
346	case *spb.Value_ListValue:
347		_, isKind = val.GetKind().(*spb.Value_ListValue)
348	case *spb.Value_StringValue:
349		_, isKind = val.GetKind().(*spb.Value_StringValue)
350	case *spb.Value_NumberValue:
351		_, isKind = val.GetKind().(*spb.Value_NumberValue)
352	default:
353		isKind = false
354	}
355	return isKind
356}
357
358func (r *RawJWT) customClaim(name string) (*spb.Value, error) {
359	if isRegisteredClaim(name) {
360		return nil, fmt.Errorf("'%q' is a registered claim", name)
361	}
362	val, ok := r.field(name)
363	if !ok {
364		return nil, fmt.Errorf("claim '%q' not found", name)
365	}
366	return val, nil
367}
368
369func (r *RawJWT) hasField(name string) bool {
370	_, ok := r.field(name)
371	return ok
372}
373
374func (r *RawJWT) field(name string) (*spb.Value, bool) {
375	val, ok := r.jsonpb.GetFields()[name]
376	return val, ok
377}
378
379// createPayload creates a JSON payload from JWT options.
380func createPayload(opts *RawJWTOptions) (*spb.Struct, error) {
381	if err := validateCustomClaims(opts.CustomClaims); err != nil {
382		return nil, err
383	}
384	if opts.ExpiresAt == nil && !opts.WithoutExpiration {
385		return nil, fmt.Errorf("jwt options must contain an expiration or must be marked WithoutExpiration")
386	}
387	if opts.ExpiresAt != nil && opts.WithoutExpiration {
388		return nil, fmt.Errorf("jwt options can't be marked WithoutExpiration when expiration is specified")
389	}
390	if opts.Audience != nil && opts.Audiences != nil {
391		return nil, fmt.Errorf("jwt options can either contain a single Audience or a list of Audiences but not both")
392	}
393
394	payload := &spb.Struct{
395		Fields: map[string]*spb.Value{},
396	}
397	setStringValue(payload, claimJWTID, opts.JWTID)
398	setStringValue(payload, claimIssuer, opts.Issuer)
399	setStringValue(payload, claimSubject, opts.Subject)
400	setStringValue(payload, claimAudience, opts.Audience)
401	setTimeValue(payload, claimIssuedAt, opts.IssuedAt)
402	setTimeValue(payload, claimNotBefore, opts.NotBefore)
403	setTimeValue(payload, claimExpiration, opts.ExpiresAt)
404	setAudiences(payload, claimAudience, opts.Audiences)
405
406	for k, v := range opts.CustomClaims {
407		val, err := spb.NewValue(v)
408		if err != nil {
409			return nil, err
410		}
411		setValue(payload, k, val)
412	}
413	return payload, nil
414}
415
416func validatePayload(payload *spb.Struct) error {
417	if payload.Fields == nil || len(payload.Fields) == 0 {
418		return nil
419	}
420	if err := validateAudienceClaim(payload.Fields[claimAudience]); err != nil {
421		return err
422	}
423	for claim, val := range payload.GetFields() {
424		if isRegisteredTimeClaim(claim) {
425			if err := validateTimeClaim(claim, val); err != nil {
426				return err
427			}
428		}
429
430		if isRegisteredStringClaim(claim) {
431			if err := validateStringClaim(claim, val); err != nil {
432				return err
433			}
434		}
435	}
436	return nil
437}
438
439func validateStringClaim(claim string, val *spb.Value) error {
440	v, ok := val.Kind.(*spb.Value_StringValue)
441	if !ok {
442		return fmt.Errorf("claim: '%q' MUST be a string", claim)
443	}
444	if !utf8.ValidString(v.StringValue) {
445		return fmt.Errorf("claim: '%q' isn't a valid UTF-8 string", claim)
446	}
447	return nil
448}
449
450func validateTimeClaim(claim string, val *spb.Value) error {
451	if _, ok := val.Kind.(*spb.Value_NumberValue); !ok {
452		return fmt.Errorf("claim %q MUST be a numeric value, ", claim)
453	}
454	t := int64(val.GetNumberValue())
455	if t > jwtTimestampMax || t < jwtTimestampMin {
456		return fmt.Errorf("invalid timestamp: '%d' for claim: %q", t, claim)
457	}
458	return nil
459}
460
461func validateAudienceClaim(val *spb.Value) error {
462	if val == nil {
463		return nil
464	}
465	_, isString := val.Kind.(*spb.Value_StringValue)
466	l, isList := val.Kind.(*spb.Value_ListValue)
467	if !isList && !isString {
468		return fmt.Errorf("audience claim MUST be a list with at least one string or a single string value")
469	}
470	if isString {
471		return validateStringClaim(claimAudience, val)
472	}
473	if l.ListValue != nil && len(l.ListValue.Values) == 0 {
474		return fmt.Errorf("there MUST be at least one value present in the audience claim")
475	}
476	for _, aud := range l.ListValue.Values {
477		v, ok := aud.Kind.(*spb.Value_StringValue)
478		if !ok {
479			return fmt.Errorf("audience value is not a string")
480		}
481		if !utf8.ValidString(v.StringValue) {
482			return fmt.Errorf("audience value is not a valid UTF-8 string")
483		}
484	}
485	return nil
486}
487
488func validateCustomClaims(cc map[string]interface{}) error {
489	if cc == nil {
490		return nil
491	}
492	for key := range cc {
493		if isRegisteredClaim(key) {
494			return fmt.Errorf("claim '%q' is a registered claim, it can't be declared as a custom claim", key)
495		}
496	}
497	return nil
498}
499
500func setTimeValue(p *spb.Struct, claim string, val *time.Time) {
501	if val == nil {
502		return
503	}
504	setValue(p, claim, spb.NewNumberValue(float64(val.Unix())))
505}
506
507func setStringValue(p *spb.Struct, claim string, val *string) {
508	if val == nil {
509		return
510	}
511	setValue(p, claim, spb.NewStringValue(*val))
512}
513
514func setAudiences(p *spb.Struct, claim string, vals []string) {
515	if vals == nil {
516		return
517	}
518	audList := &spb.ListValue{
519		Values: make([]*spb.Value, 0, len(vals)),
520	}
521	for _, aud := range vals {
522		audList.Values = append(audList.Values, spb.NewStringValue(aud))
523	}
524	setValue(p, claim, spb.NewListValue(audList))
525}
526
527func setValue(p *spb.Struct, claim string, val *spb.Value) {
528	if p.GetFields() == nil {
529		p.Fields = make(map[string]*spb.Value)
530	}
531	p.GetFields()[claim] = val
532}
533
534func isRegisteredClaim(c string) bool {
535	return isRegisteredStringClaim(c) || isRegisteredTimeClaim(c) || c == claimAudience
536}
537
538func isRegisteredStringClaim(c string) bool {
539	return c == claimIssuer || c == claimSubject || c == claimJWTID
540}
541
542func isRegisteredTimeClaim(c string) bool {
543	return c == claimExpiration || c == claimNotBefore || c == claimIssuedAt
544}
545