1*e7b1675dSTing-Kang Chang// Copyright 2022 Google LLC 2*e7b1675dSTing-Kang Chang// 3*e7b1675dSTing-Kang Chang// Licensed under the Apache License, Version 2.0 (the "License"); 4*e7b1675dSTing-Kang Chang// you may not use this file except in compliance with the License. 5*e7b1675dSTing-Kang Chang// You may obtain a copy of the License at 6*e7b1675dSTing-Kang Chang// 7*e7b1675dSTing-Kang Chang// http://www.apache.org/licenses/LICENSE-2.0 8*e7b1675dSTing-Kang Chang// 9*e7b1675dSTing-Kang Chang// Unless required by applicable law or agreed to in writing, software 10*e7b1675dSTing-Kang Chang// distributed under the License is distributed on an "AS IS" BASIS, 11*e7b1675dSTing-Kang Chang// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12*e7b1675dSTing-Kang Chang// See the License for the specific language governing permissions and 13*e7b1675dSTing-Kang Chang// limitations under the License. 14*e7b1675dSTing-Kang Chang// 15*e7b1675dSTing-Kang Chang//////////////////////////////////////////////////////////////////////////////// 16*e7b1675dSTing-Kang Chang 17*e7b1675dSTing-Kang Changpackage jwt 18*e7b1675dSTing-Kang Chang 19*e7b1675dSTing-Kang Changimport ( 20*e7b1675dSTing-Kang Chang "encoding/base64" 21*e7b1675dSTing-Kang Chang "encoding/binary" 22*e7b1675dSTing-Kang Chang "fmt" 23*e7b1675dSTing-Kang Chang "strings" 24*e7b1675dSTing-Kang Chang 25*e7b1675dSTing-Kang Chang spb "google.golang.org/protobuf/types/known/structpb" 26*e7b1675dSTing-Kang Chang tpb "github.com/google/tink/go/proto/tink_go_proto" 27*e7b1675dSTing-Kang Chang) 28*e7b1675dSTing-Kang Chang 29*e7b1675dSTing-Kang Chang// keyID returns the keyID in big endian format base64 encoded if the key output prefix is of type Tink or nil otherwise. 30*e7b1675dSTing-Kang Changfunc keyID(keyID uint32, outPrefixType tpb.OutputPrefixType) *string { 31*e7b1675dSTing-Kang Chang if outPrefixType != tpb.OutputPrefixType_TINK { 32*e7b1675dSTing-Kang Chang return nil 33*e7b1675dSTing-Kang Chang } 34*e7b1675dSTing-Kang Chang buf := make([]byte, 4) 35*e7b1675dSTing-Kang Chang binary.BigEndian.PutUint32(buf, keyID) 36*e7b1675dSTing-Kang Chang s := base64Encode(buf) 37*e7b1675dSTing-Kang Chang return &s 38*e7b1675dSTing-Kang Chang} 39*e7b1675dSTing-Kang Chang 40*e7b1675dSTing-Kang Chang// createUnsigned creates an unsigned JWT by created the header/payload, encoding them to a websafe base64 encoded string and concatenating. 41*e7b1675dSTing-Kang Changfunc createUnsigned(rawJWT *RawJWT, algo string, tinkKID *string, customKID *string) (string, error) { 42*e7b1675dSTing-Kang Chang if rawJWT == nil { 43*e7b1675dSTing-Kang Chang return "", fmt.Errorf("rawJWT is nil") 44*e7b1675dSTing-Kang Chang } 45*e7b1675dSTing-Kang Chang var typeHeader *string = nil 46*e7b1675dSTing-Kang Chang if rawJWT.HasTypeHeader() { 47*e7b1675dSTing-Kang Chang th, err := rawJWT.TypeHeader() 48*e7b1675dSTing-Kang Chang if err != nil { 49*e7b1675dSTing-Kang Chang return "", err 50*e7b1675dSTing-Kang Chang } 51*e7b1675dSTing-Kang Chang typeHeader = &th 52*e7b1675dSTing-Kang Chang } 53*e7b1675dSTing-Kang Chang if customKID != nil && tinkKID != nil { 54*e7b1675dSTing-Kang Chang return "", fmt.Errorf("TINK Keys are not allowed to have a kid value set") 55*e7b1675dSTing-Kang Chang } 56*e7b1675dSTing-Kang Chang if tinkKID != nil { 57*e7b1675dSTing-Kang Chang customKID = tinkKID 58*e7b1675dSTing-Kang Chang } 59*e7b1675dSTing-Kang Chang encodedHeader, err := createHeader(algo, typeHeader, customKID) 60*e7b1675dSTing-Kang Chang if err != nil { 61*e7b1675dSTing-Kang Chang return "", err 62*e7b1675dSTing-Kang Chang } 63*e7b1675dSTing-Kang Chang payload, err := rawJWT.JSONPayload() 64*e7b1675dSTing-Kang Chang if err != nil { 65*e7b1675dSTing-Kang Chang return "", err 66*e7b1675dSTing-Kang Chang } 67*e7b1675dSTing-Kang Chang return dotConcat(encodedHeader, base64Encode(payload)), nil 68*e7b1675dSTing-Kang Chang} 69*e7b1675dSTing-Kang Chang 70*e7b1675dSTing-Kang Chang// combineUnsignedAndSignature combines the token with the raw signature to provide a signed token. 71*e7b1675dSTing-Kang Changfunc combineUnsignedAndSignature(unsigned string, signature []byte) string { 72*e7b1675dSTing-Kang Chang return dotConcat(unsigned, base64Encode(signature)) 73*e7b1675dSTing-Kang Chang} 74*e7b1675dSTing-Kang Chang 75*e7b1675dSTing-Kang Chang// splitSignedCompact extracts the witness and usigned JWT. 76*e7b1675dSTing-Kang Changfunc splitSignedCompact(compact string) ([]byte, string, error) { 77*e7b1675dSTing-Kang Chang i := strings.LastIndex(compact, ".") 78*e7b1675dSTing-Kang Chang if i < 0 { 79*e7b1675dSTing-Kang Chang return nil, "", fmt.Errorf("invalid token") 80*e7b1675dSTing-Kang Chang } 81*e7b1675dSTing-Kang Chang witness, err := base64Decode(compact[i+1:]) 82*e7b1675dSTing-Kang Chang if err != nil { 83*e7b1675dSTing-Kang Chang return nil, "", fmt.Errorf("%q: %v", compact[i+1:], err) 84*e7b1675dSTing-Kang Chang } 85*e7b1675dSTing-Kang Chang if len(witness) == 0 { 86*e7b1675dSTing-Kang Chang return nil, "", fmt.Errorf("empty signature") 87*e7b1675dSTing-Kang Chang } 88*e7b1675dSTing-Kang Chang unsigned := compact[0:i] 89*e7b1675dSTing-Kang Chang if len(unsigned) == 0 { 90*e7b1675dSTing-Kang Chang return nil, "", fmt.Errorf("empty content") 91*e7b1675dSTing-Kang Chang } 92*e7b1675dSTing-Kang Chang if strings.Count(unsigned, ".") != 1 { 93*e7b1675dSTing-Kang Chang return nil, "", fmt.Errorf("only tokens in JWS compact serialization formats are supported") 94*e7b1675dSTing-Kang Chang } 95*e7b1675dSTing-Kang Chang return witness, unsigned, nil 96*e7b1675dSTing-Kang Chang} 97*e7b1675dSTing-Kang Chang 98*e7b1675dSTing-Kang Chang// decodeUnsignedTokenAndValidateHeader verifies the header on an unsigned JWT and decodes the payload into a RawJWT. 99*e7b1675dSTing-Kang Chang// Expects the token to be in compact serialization format. The signature should be verified before calling this function. 100*e7b1675dSTing-Kang Changfunc decodeUnsignedTokenAndValidateHeader(unsigned, algorithm string, tinkKID, customKID *string) (*RawJWT, error) { 101*e7b1675dSTing-Kang Chang parts := strings.Split(unsigned, ".") 102*e7b1675dSTing-Kang Chang if len(parts) != 2 { 103*e7b1675dSTing-Kang Chang return nil, fmt.Errorf("only tokens in JWS compact serialization formats are supported") 104*e7b1675dSTing-Kang Chang } 105*e7b1675dSTing-Kang Chang jsonHeader, err := base64Decode(parts[0]) 106*e7b1675dSTing-Kang Chang if err != nil { 107*e7b1675dSTing-Kang Chang return nil, err 108*e7b1675dSTing-Kang Chang } 109*e7b1675dSTing-Kang Chang header, err := jsonToStruct(jsonHeader) 110*e7b1675dSTing-Kang Chang if err != nil { 111*e7b1675dSTing-Kang Chang return nil, err 112*e7b1675dSTing-Kang Chang } 113*e7b1675dSTing-Kang Chang if err := validateHeader(header, algorithm, tinkKID, customKID); err != nil { 114*e7b1675dSTing-Kang Chang return nil, err 115*e7b1675dSTing-Kang Chang } 116*e7b1675dSTing-Kang Chang typeHeader, err := extractTypeHeader(header) 117*e7b1675dSTing-Kang Chang if err != nil { 118*e7b1675dSTing-Kang Chang return nil, err 119*e7b1675dSTing-Kang Chang } 120*e7b1675dSTing-Kang Chang jsonPayload, err := base64Decode(parts[1]) 121*e7b1675dSTing-Kang Chang if err != nil { 122*e7b1675dSTing-Kang Chang return nil, err 123*e7b1675dSTing-Kang Chang } 124*e7b1675dSTing-Kang Chang return NewRawJWTFromJSON(typeHeader, jsonPayload) 125*e7b1675dSTing-Kang Chang} 126*e7b1675dSTing-Kang Chang 127*e7b1675dSTing-Kang Chang// base64Encode encodes a byte array into a base64 URL safe string with no padding. 128*e7b1675dSTing-Kang Changfunc base64Encode(content []byte) string { 129*e7b1675dSTing-Kang Chang return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(content) 130*e7b1675dSTing-Kang Chang} 131*e7b1675dSTing-Kang Chang 132*e7b1675dSTing-Kang Chang// base64Decode decodes a URL safe base64 encoded string into a byte array ignoring padding. 133*e7b1675dSTing-Kang Changfunc base64Decode(content string) ([]byte, error) { 134*e7b1675dSTing-Kang Chang for _, c := range content { 135*e7b1675dSTing-Kang Chang if !isValidURLsafeBase64Char(c) { 136*e7b1675dSTing-Kang Chang return nil, fmt.Errorf("invalid encoding") 137*e7b1675dSTing-Kang Chang } 138*e7b1675dSTing-Kang Chang } 139*e7b1675dSTing-Kang Chang return base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(content) 140*e7b1675dSTing-Kang Chang} 141*e7b1675dSTing-Kang Chang 142*e7b1675dSTing-Kang Changfunc isValidURLsafeBase64Char(c rune) bool { 143*e7b1675dSTing-Kang Chang return (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')) || 144*e7b1675dSTing-Kang Chang ((c >= '0') && (c <= '9')) || ((c == '-') || (c == '_'))) 145*e7b1675dSTing-Kang Chang} 146*e7b1675dSTing-Kang Chang 147*e7b1675dSTing-Kang Changfunc dotConcat(a, b string) string { 148*e7b1675dSTing-Kang Chang return fmt.Sprintf("%s.%s", a, b) 149*e7b1675dSTing-Kang Chang} 150*e7b1675dSTing-Kang Chang 151*e7b1675dSTing-Kang Changfunc jsonToStruct(jsonPayload []byte) (*spb.Struct, error) { 152*e7b1675dSTing-Kang Chang payload := &spb.Struct{} 153*e7b1675dSTing-Kang Chang if err := payload.UnmarshalJSON(jsonPayload); err != nil { 154*e7b1675dSTing-Kang Chang return nil, err 155*e7b1675dSTing-Kang Chang } 156*e7b1675dSTing-Kang Chang return payload, nil 157*e7b1675dSTing-Kang Chang} 158*e7b1675dSTing-Kang Chang 159*e7b1675dSTing-Kang Changfunc extractTypeHeader(header *spb.Struct) (*string, error) { 160*e7b1675dSTing-Kang Chang fields := header.GetFields() 161*e7b1675dSTing-Kang Chang if fields == nil { 162*e7b1675dSTing-Kang Chang return nil, fmt.Errorf("header contains no fields") 163*e7b1675dSTing-Kang Chang } 164*e7b1675dSTing-Kang Chang val, ok := fields["typ"] 165*e7b1675dSTing-Kang Chang if !ok { 166*e7b1675dSTing-Kang Chang return nil, nil 167*e7b1675dSTing-Kang Chang } 168*e7b1675dSTing-Kang Chang str, ok := val.Kind.(*spb.Value_StringValue) 169*e7b1675dSTing-Kang Chang if !ok { 170*e7b1675dSTing-Kang Chang return nil, fmt.Errorf("type header isn't a string") 171*e7b1675dSTing-Kang Chang } 172*e7b1675dSTing-Kang Chang return &str.StringValue, nil 173*e7b1675dSTing-Kang Chang} 174*e7b1675dSTing-Kang Chang 175*e7b1675dSTing-Kang Changfunc createHeader(algorithm string, typeHeader, kid *string) (string, error) { 176*e7b1675dSTing-Kang Chang header := &spb.Struct{ 177*e7b1675dSTing-Kang Chang Fields: map[string]*spb.Value{ 178*e7b1675dSTing-Kang Chang "alg": spb.NewStringValue(algorithm), 179*e7b1675dSTing-Kang Chang }, 180*e7b1675dSTing-Kang Chang } 181*e7b1675dSTing-Kang Chang if typeHeader != nil { 182*e7b1675dSTing-Kang Chang header.Fields["typ"] = spb.NewStringValue(*typeHeader) 183*e7b1675dSTing-Kang Chang } 184*e7b1675dSTing-Kang Chang if kid != nil { 185*e7b1675dSTing-Kang Chang header.Fields["kid"] = spb.NewStringValue(*kid) 186*e7b1675dSTing-Kang Chang } 187*e7b1675dSTing-Kang Chang jsonHeader, err := header.MarshalJSON() 188*e7b1675dSTing-Kang Chang if err != nil { 189*e7b1675dSTing-Kang Chang return "", err 190*e7b1675dSTing-Kang Chang } 191*e7b1675dSTing-Kang Chang return base64Encode(jsonHeader), nil 192*e7b1675dSTing-Kang Chang} 193*e7b1675dSTing-Kang Chang 194*e7b1675dSTing-Kang Changfunc validateHeader(header *spb.Struct, algorithm string, tinkKID, customKID *string) error { 195*e7b1675dSTing-Kang Chang fields := header.GetFields() 196*e7b1675dSTing-Kang Chang if fields == nil { 197*e7b1675dSTing-Kang Chang return fmt.Errorf("header contains no fields") 198*e7b1675dSTing-Kang Chang } 199*e7b1675dSTing-Kang Chang alg, err := headerStringField(fields, "alg") 200*e7b1675dSTing-Kang Chang if err != nil { 201*e7b1675dSTing-Kang Chang return err 202*e7b1675dSTing-Kang Chang } 203*e7b1675dSTing-Kang Chang if alg != algorithm { 204*e7b1675dSTing-Kang Chang return fmt.Errorf("invalid alg") 205*e7b1675dSTing-Kang Chang } 206*e7b1675dSTing-Kang Chang if _, ok := fields["crit"]; ok { 207*e7b1675dSTing-Kang Chang return fmt.Errorf("all tokens with crit headers are rejected") 208*e7b1675dSTing-Kang Chang } 209*e7b1675dSTing-Kang Chang if tinkKID != nil && customKID != nil { 210*e7b1675dSTing-Kang Chang return fmt.Errorf("custom_kid can only be set for RAW keys") 211*e7b1675dSTing-Kang Chang } 212*e7b1675dSTing-Kang Chang _, hasKID := fields["kid"] 213*e7b1675dSTing-Kang Chang if tinkKID != nil && !hasKID { 214*e7b1675dSTing-Kang Chang return fmt.Errorf("missing kid in header") 215*e7b1675dSTing-Kang Chang } 216*e7b1675dSTing-Kang Chang if tinkKID != nil { 217*e7b1675dSTing-Kang Chang return validateKIDInHeader(fields, tinkKID) 218*e7b1675dSTing-Kang Chang } 219*e7b1675dSTing-Kang Chang if hasKID && customKID != nil { 220*e7b1675dSTing-Kang Chang return validateKIDInHeader(fields, customKID) 221*e7b1675dSTing-Kang Chang } 222*e7b1675dSTing-Kang Chang return nil 223*e7b1675dSTing-Kang Chang} 224*e7b1675dSTing-Kang Chang 225*e7b1675dSTing-Kang Changfunc validateKIDInHeader(fields map[string]*spb.Value, kid *string) error { 226*e7b1675dSTing-Kang Chang headerKID, err := headerStringField(fields, "kid") 227*e7b1675dSTing-Kang Chang if err != nil { 228*e7b1675dSTing-Kang Chang return err 229*e7b1675dSTing-Kang Chang } 230*e7b1675dSTing-Kang Chang if headerKID != *kid { 231*e7b1675dSTing-Kang Chang return fmt.Errorf("invalid kid header") 232*e7b1675dSTing-Kang Chang } 233*e7b1675dSTing-Kang Chang return nil 234*e7b1675dSTing-Kang Chang} 235*e7b1675dSTing-Kang Chang 236*e7b1675dSTing-Kang Changfunc headerStringField(fields map[string]*spb.Value, name string) (string, error) { 237*e7b1675dSTing-Kang Chang val, ok := fields[name] 238*e7b1675dSTing-Kang Chang if !ok { 239*e7b1675dSTing-Kang Chang return "", fmt.Errorf("header is missing %q", name) 240*e7b1675dSTing-Kang Chang } 241*e7b1675dSTing-Kang Chang str, ok := val.Kind.(*spb.Value_StringValue) 242*e7b1675dSTing-Kang Chang if !ok { 243*e7b1675dSTing-Kang Chang return "", fmt.Errorf("%q header isn't a string", name) 244*e7b1675dSTing-Kang Chang } 245*e7b1675dSTing-Kang Chang return str.StringValue, nil 246*e7b1675dSTing-Kang Chang} 247