1// Copyright 2022 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package jwt 18 19import ( 20 "encoding/base64" 21 "encoding/binary" 22 "fmt" 23 "strings" 24 25 spb "google.golang.org/protobuf/types/known/structpb" 26 tpb "github.com/google/tink/go/proto/tink_go_proto" 27) 28 29// keyID returns the keyID in big endian format base64 encoded if the key output prefix is of type Tink or nil otherwise. 30func keyID(keyID uint32, outPrefixType tpb.OutputPrefixType) *string { 31 if outPrefixType != tpb.OutputPrefixType_TINK { 32 return nil 33 } 34 buf := make([]byte, 4) 35 binary.BigEndian.PutUint32(buf, keyID) 36 s := base64Encode(buf) 37 return &s 38} 39 40// createUnsigned creates an unsigned JWT by created the header/payload, encoding them to a websafe base64 encoded string and concatenating. 41func createUnsigned(rawJWT *RawJWT, algo string, tinkKID *string, customKID *string) (string, error) { 42 if rawJWT == nil { 43 return "", fmt.Errorf("rawJWT is nil") 44 } 45 var typeHeader *string = nil 46 if rawJWT.HasTypeHeader() { 47 th, err := rawJWT.TypeHeader() 48 if err != nil { 49 return "", err 50 } 51 typeHeader = &th 52 } 53 if customKID != nil && tinkKID != nil { 54 return "", fmt.Errorf("TINK Keys are not allowed to have a kid value set") 55 } 56 if tinkKID != nil { 57 customKID = tinkKID 58 } 59 encodedHeader, err := createHeader(algo, typeHeader, customKID) 60 if err != nil { 61 return "", err 62 } 63 payload, err := rawJWT.JSONPayload() 64 if err != nil { 65 return "", err 66 } 67 return dotConcat(encodedHeader, base64Encode(payload)), nil 68} 69 70// combineUnsignedAndSignature combines the token with the raw signature to provide a signed token. 71func combineUnsignedAndSignature(unsigned string, signature []byte) string { 72 return dotConcat(unsigned, base64Encode(signature)) 73} 74 75// splitSignedCompact extracts the witness and usigned JWT. 76func splitSignedCompact(compact string) ([]byte, string, error) { 77 i := strings.LastIndex(compact, ".") 78 if i < 0 { 79 return nil, "", fmt.Errorf("invalid token") 80 } 81 witness, err := base64Decode(compact[i+1:]) 82 if err != nil { 83 return nil, "", fmt.Errorf("%q: %v", compact[i+1:], err) 84 } 85 if len(witness) == 0 { 86 return nil, "", fmt.Errorf("empty signature") 87 } 88 unsigned := compact[0:i] 89 if len(unsigned) == 0 { 90 return nil, "", fmt.Errorf("empty content") 91 } 92 if strings.Count(unsigned, ".") != 1 { 93 return nil, "", fmt.Errorf("only tokens in JWS compact serialization formats are supported") 94 } 95 return witness, unsigned, nil 96} 97 98// decodeUnsignedTokenAndValidateHeader verifies the header on an unsigned JWT and decodes the payload into a RawJWT. 99// Expects the token to be in compact serialization format. The signature should be verified before calling this function. 100func decodeUnsignedTokenAndValidateHeader(unsigned, algorithm string, tinkKID, customKID *string) (*RawJWT, error) { 101 parts := strings.Split(unsigned, ".") 102 if len(parts) != 2 { 103 return nil, fmt.Errorf("only tokens in JWS compact serialization formats are supported") 104 } 105 jsonHeader, err := base64Decode(parts[0]) 106 if err != nil { 107 return nil, err 108 } 109 header, err := jsonToStruct(jsonHeader) 110 if err != nil { 111 return nil, err 112 } 113 if err := validateHeader(header, algorithm, tinkKID, customKID); err != nil { 114 return nil, err 115 } 116 typeHeader, err := extractTypeHeader(header) 117 if err != nil { 118 return nil, err 119 } 120 jsonPayload, err := base64Decode(parts[1]) 121 if err != nil { 122 return nil, err 123 } 124 return NewRawJWTFromJSON(typeHeader, jsonPayload) 125} 126 127// base64Encode encodes a byte array into a base64 URL safe string with no padding. 128func base64Encode(content []byte) string { 129 return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(content) 130} 131 132// base64Decode decodes a URL safe base64 encoded string into a byte array ignoring padding. 133func base64Decode(content string) ([]byte, error) { 134 for _, c := range content { 135 if !isValidURLsafeBase64Char(c) { 136 return nil, fmt.Errorf("invalid encoding") 137 } 138 } 139 return base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(content) 140} 141 142func isValidURLsafeBase64Char(c rune) bool { 143 return (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')) || 144 ((c >= '0') && (c <= '9')) || ((c == '-') || (c == '_'))) 145} 146 147func dotConcat(a, b string) string { 148 return fmt.Sprintf("%s.%s", a, b) 149} 150 151func jsonToStruct(jsonPayload []byte) (*spb.Struct, error) { 152 payload := &spb.Struct{} 153 if err := payload.UnmarshalJSON(jsonPayload); err != nil { 154 return nil, err 155 } 156 return payload, nil 157} 158 159func extractTypeHeader(header *spb.Struct) (*string, error) { 160 fields := header.GetFields() 161 if fields == nil { 162 return nil, fmt.Errorf("header contains no fields") 163 } 164 val, ok := fields["typ"] 165 if !ok { 166 return nil, nil 167 } 168 str, ok := val.Kind.(*spb.Value_StringValue) 169 if !ok { 170 return nil, fmt.Errorf("type header isn't a string") 171 } 172 return &str.StringValue, nil 173} 174 175func createHeader(algorithm string, typeHeader, kid *string) (string, error) { 176 header := &spb.Struct{ 177 Fields: map[string]*spb.Value{ 178 "alg": spb.NewStringValue(algorithm), 179 }, 180 } 181 if typeHeader != nil { 182 header.Fields["typ"] = spb.NewStringValue(*typeHeader) 183 } 184 if kid != nil { 185 header.Fields["kid"] = spb.NewStringValue(*kid) 186 } 187 jsonHeader, err := header.MarshalJSON() 188 if err != nil { 189 return "", err 190 } 191 return base64Encode(jsonHeader), nil 192} 193 194func validateHeader(header *spb.Struct, algorithm string, tinkKID, customKID *string) error { 195 fields := header.GetFields() 196 if fields == nil { 197 return fmt.Errorf("header contains no fields") 198 } 199 alg, err := headerStringField(fields, "alg") 200 if err != nil { 201 return err 202 } 203 if alg != algorithm { 204 return fmt.Errorf("invalid alg") 205 } 206 if _, ok := fields["crit"]; ok { 207 return fmt.Errorf("all tokens with crit headers are rejected") 208 } 209 if tinkKID != nil && customKID != nil { 210 return fmt.Errorf("custom_kid can only be set for RAW keys") 211 } 212 _, hasKID := fields["kid"] 213 if tinkKID != nil && !hasKID { 214 return fmt.Errorf("missing kid in header") 215 } 216 if tinkKID != nil { 217 return validateKIDInHeader(fields, tinkKID) 218 } 219 if hasKID && customKID != nil { 220 return validateKIDInHeader(fields, customKID) 221 } 222 return nil 223} 224 225func validateKIDInHeader(fields map[string]*spb.Value, kid *string) error { 226 headerKID, err := headerStringField(fields, "kid") 227 if err != nil { 228 return err 229 } 230 if headerKID != *kid { 231 return fmt.Errorf("invalid kid header") 232 } 233 return nil 234} 235 236func headerStringField(fields map[string]*spb.Value, name string) (string, error) { 237 val, ok := fields[name] 238 if !ok { 239 return "", fmt.Errorf("header is missing %q", name) 240 } 241 str, ok := val.Kind.(*spb.Value_StringValue) 242 if !ok { 243 return "", fmt.Errorf("%q header isn't a string", name) 244 } 245 return str.StringValue, nil 246} 247