// Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //////////////////////////////////////////////////////////////////////////////// package jwt import ( "encoding/base64" "encoding/binary" "fmt" "strings" spb "google.golang.org/protobuf/types/known/structpb" tpb "github.com/google/tink/go/proto/tink_go_proto" ) // keyID returns the keyID in big endian format base64 encoded if the key output prefix is of type Tink or nil otherwise. func keyID(keyID uint32, outPrefixType tpb.OutputPrefixType) *string { if outPrefixType != tpb.OutputPrefixType_TINK { return nil } buf := make([]byte, 4) binary.BigEndian.PutUint32(buf, keyID) s := base64Encode(buf) return &s } // createUnsigned creates an unsigned JWT by created the header/payload, encoding them to a websafe base64 encoded string and concatenating. func createUnsigned(rawJWT *RawJWT, algo string, tinkKID *string, customKID *string) (string, error) { if rawJWT == nil { return "", fmt.Errorf("rawJWT is nil") } var typeHeader *string = nil if rawJWT.HasTypeHeader() { th, err := rawJWT.TypeHeader() if err != nil { return "", err } typeHeader = &th } if customKID != nil && tinkKID != nil { return "", fmt.Errorf("TINK Keys are not allowed to have a kid value set") } if tinkKID != nil { customKID = tinkKID } encodedHeader, err := createHeader(algo, typeHeader, customKID) if err != nil { return "", err } payload, err := rawJWT.JSONPayload() if err != nil { return "", err } return dotConcat(encodedHeader, base64Encode(payload)), nil } // combineUnsignedAndSignature combines the token with the raw signature to provide a signed token. func combineUnsignedAndSignature(unsigned string, signature []byte) string { return dotConcat(unsigned, base64Encode(signature)) } // splitSignedCompact extracts the witness and usigned JWT. func splitSignedCompact(compact string) ([]byte, string, error) { i := strings.LastIndex(compact, ".") if i < 0 { return nil, "", fmt.Errorf("invalid token") } witness, err := base64Decode(compact[i+1:]) if err != nil { return nil, "", fmt.Errorf("%q: %v", compact[i+1:], err) } if len(witness) == 0 { return nil, "", fmt.Errorf("empty signature") } unsigned := compact[0:i] if len(unsigned) == 0 { return nil, "", fmt.Errorf("empty content") } if strings.Count(unsigned, ".") != 1 { return nil, "", fmt.Errorf("only tokens in JWS compact serialization formats are supported") } return witness, unsigned, nil } // decodeUnsignedTokenAndValidateHeader verifies the header on an unsigned JWT and decodes the payload into a RawJWT. // Expects the token to be in compact serialization format. The signature should be verified before calling this function. func decodeUnsignedTokenAndValidateHeader(unsigned, algorithm string, tinkKID, customKID *string) (*RawJWT, error) { parts := strings.Split(unsigned, ".") if len(parts) != 2 { return nil, fmt.Errorf("only tokens in JWS compact serialization formats are supported") } jsonHeader, err := base64Decode(parts[0]) if err != nil { return nil, err } header, err := jsonToStruct(jsonHeader) if err != nil { return nil, err } if err := validateHeader(header, algorithm, tinkKID, customKID); err != nil { return nil, err } typeHeader, err := extractTypeHeader(header) if err != nil { return nil, err } jsonPayload, err := base64Decode(parts[1]) if err != nil { return nil, err } return NewRawJWTFromJSON(typeHeader, jsonPayload) } // base64Encode encodes a byte array into a base64 URL safe string with no padding. func base64Encode(content []byte) string { return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(content) } // base64Decode decodes a URL safe base64 encoded string into a byte array ignoring padding. func base64Decode(content string) ([]byte, error) { for _, c := range content { if !isValidURLsafeBase64Char(c) { return nil, fmt.Errorf("invalid encoding") } } return base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(content) } func isValidURLsafeBase64Char(c rune) bool { return (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')) || ((c >= '0') && (c <= '9')) || ((c == '-') || (c == '_'))) } func dotConcat(a, b string) string { return fmt.Sprintf("%s.%s", a, b) } func jsonToStruct(jsonPayload []byte) (*spb.Struct, error) { payload := &spb.Struct{} if err := payload.UnmarshalJSON(jsonPayload); err != nil { return nil, err } return payload, nil } func extractTypeHeader(header *spb.Struct) (*string, error) { fields := header.GetFields() if fields == nil { return nil, fmt.Errorf("header contains no fields") } val, ok := fields["typ"] if !ok { return nil, nil } str, ok := val.Kind.(*spb.Value_StringValue) if !ok { return nil, fmt.Errorf("type header isn't a string") } return &str.StringValue, nil } func createHeader(algorithm string, typeHeader, kid *string) (string, error) { header := &spb.Struct{ Fields: map[string]*spb.Value{ "alg": spb.NewStringValue(algorithm), }, } if typeHeader != nil { header.Fields["typ"] = spb.NewStringValue(*typeHeader) } if kid != nil { header.Fields["kid"] = spb.NewStringValue(*kid) } jsonHeader, err := header.MarshalJSON() if err != nil { return "", err } return base64Encode(jsonHeader), nil } func validateHeader(header *spb.Struct, algorithm string, tinkKID, customKID *string) error { fields := header.GetFields() if fields == nil { return fmt.Errorf("header contains no fields") } alg, err := headerStringField(fields, "alg") if err != nil { return err } if alg != algorithm { return fmt.Errorf("invalid alg") } if _, ok := fields["crit"]; ok { return fmt.Errorf("all tokens with crit headers are rejected") } if tinkKID != nil && customKID != nil { return fmt.Errorf("custom_kid can only be set for RAW keys") } _, hasKID := fields["kid"] if tinkKID != nil && !hasKID { return fmt.Errorf("missing kid in header") } if tinkKID != nil { return validateKIDInHeader(fields, tinkKID) } if hasKID && customKID != nil { return validateKIDInHeader(fields, customKID) } return nil } func validateKIDInHeader(fields map[string]*spb.Value, kid *string) error { headerKID, err := headerStringField(fields, "kid") if err != nil { return err } if headerKID != *kid { return fmt.Errorf("invalid kid header") } return nil } func headerStringField(fields map[string]*spb.Value, name string) (string, error) { val, ok := fields[name] if !ok { return "", fmt.Errorf("header is missing %q", name) } str, ok := val.Kind.(*spb.Value_StringValue) if !ok { return "", fmt.Errorf("%q header isn't a string", name) } return str.StringValue, nil }