xref: /aosp_15_r20/external/tink/go/jwt/jwt_encoding.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package jwt
18
19import (
20	"encoding/base64"
21	"encoding/binary"
22	"fmt"
23	"strings"
24
25	spb "google.golang.org/protobuf/types/known/structpb"
26	tpb "github.com/google/tink/go/proto/tink_go_proto"
27)
28
29// keyID returns the keyID in big endian format base64 encoded if the key output prefix is of type Tink or nil otherwise.
30func keyID(keyID uint32, outPrefixType tpb.OutputPrefixType) *string {
31	if outPrefixType != tpb.OutputPrefixType_TINK {
32		return nil
33	}
34	buf := make([]byte, 4)
35	binary.BigEndian.PutUint32(buf, keyID)
36	s := base64Encode(buf)
37	return &s
38}
39
40// createUnsigned creates an unsigned JWT by created the header/payload, encoding them to a websafe base64 encoded string and concatenating.
41func createUnsigned(rawJWT *RawJWT, algo string, tinkKID *string, customKID *string) (string, error) {
42	if rawJWT == nil {
43		return "", fmt.Errorf("rawJWT is nil")
44	}
45	var typeHeader *string = nil
46	if rawJWT.HasTypeHeader() {
47		th, err := rawJWT.TypeHeader()
48		if err != nil {
49			return "", err
50		}
51		typeHeader = &th
52	}
53	if customKID != nil && tinkKID != nil {
54		return "", fmt.Errorf("TINK Keys are not allowed to have a kid value set")
55	}
56	if tinkKID != nil {
57		customKID = tinkKID
58	}
59	encodedHeader, err := createHeader(algo, typeHeader, customKID)
60	if err != nil {
61		return "", err
62	}
63	payload, err := rawJWT.JSONPayload()
64	if err != nil {
65		return "", err
66	}
67	return dotConcat(encodedHeader, base64Encode(payload)), nil
68}
69
70// combineUnsignedAndSignature combines the token with the raw signature to provide a signed token.
71func combineUnsignedAndSignature(unsigned string, signature []byte) string {
72	return dotConcat(unsigned, base64Encode(signature))
73}
74
75// splitSignedCompact extracts the witness and usigned JWT.
76func splitSignedCompact(compact string) ([]byte, string, error) {
77	i := strings.LastIndex(compact, ".")
78	if i < 0 {
79		return nil, "", fmt.Errorf("invalid token")
80	}
81	witness, err := base64Decode(compact[i+1:])
82	if err != nil {
83		return nil, "", fmt.Errorf("%q: %v", compact[i+1:], err)
84	}
85	if len(witness) == 0 {
86		return nil, "", fmt.Errorf("empty signature")
87	}
88	unsigned := compact[0:i]
89	if len(unsigned) == 0 {
90		return nil, "", fmt.Errorf("empty content")
91	}
92	if strings.Count(unsigned, ".") != 1 {
93		return nil, "", fmt.Errorf("only tokens in JWS compact serialization formats are supported")
94	}
95	return witness, unsigned, nil
96}
97
98// decodeUnsignedTokenAndValidateHeader verifies the header on an unsigned JWT and decodes the payload into a RawJWT.
99// Expects the token to be in compact serialization format. The signature should be verified before calling this function.
100func decodeUnsignedTokenAndValidateHeader(unsigned, algorithm string, tinkKID, customKID *string) (*RawJWT, error) {
101	parts := strings.Split(unsigned, ".")
102	if len(parts) != 2 {
103		return nil, fmt.Errorf("only tokens in JWS compact serialization formats are supported")
104	}
105	jsonHeader, err := base64Decode(parts[0])
106	if err != nil {
107		return nil, err
108	}
109	header, err := jsonToStruct(jsonHeader)
110	if err != nil {
111		return nil, err
112	}
113	if err := validateHeader(header, algorithm, tinkKID, customKID); err != nil {
114		return nil, err
115	}
116	typeHeader, err := extractTypeHeader(header)
117	if err != nil {
118		return nil, err
119	}
120	jsonPayload, err := base64Decode(parts[1])
121	if err != nil {
122		return nil, err
123	}
124	return NewRawJWTFromJSON(typeHeader, jsonPayload)
125}
126
// base64Encode returns the URL-safe base64 encoding of content without '='
// padding. RawURLEncoding is the stdlib's predefined
// URLEncoding.WithPadding(NoPadding).
func base64Encode(content []byte) string {
	return base64.RawURLEncoding.EncodeToString(content)
}
130}
131
132// base64Decode decodes a URL safe base64 encoded string into a byte array ignoring padding.
133func base64Decode(content string) ([]byte, error) {
134	for _, c := range content {
135		if !isValidURLsafeBase64Char(c) {
136			return nil, fmt.Errorf("invalid encoding")
137		}
138	}
139	return base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(content)
140}
141
142func isValidURLsafeBase64Char(c rune) bool {
143	return (((c >= 'a') && (c <= 'z')) || ((c >= 'A') && (c <= 'Z')) ||
144		((c >= '0') && (c <= '9')) || ((c == '-') || (c == '_')))
145}
146
147func dotConcat(a, b string) string {
148	return fmt.Sprintf("%s.%s", a, b)
149}
150
151func jsonToStruct(jsonPayload []byte) (*spb.Struct, error) {
152	payload := &spb.Struct{}
153	if err := payload.UnmarshalJSON(jsonPayload); err != nil {
154		return nil, err
155	}
156	return payload, nil
157}
158
159func extractTypeHeader(header *spb.Struct) (*string, error) {
160	fields := header.GetFields()
161	if fields == nil {
162		return nil, fmt.Errorf("header contains no fields")
163	}
164	val, ok := fields["typ"]
165	if !ok {
166		return nil, nil
167	}
168	str, ok := val.Kind.(*spb.Value_StringValue)
169	if !ok {
170		return nil, fmt.Errorf("type header isn't a string")
171	}
172	return &str.StringValue, nil
173}
174
175func createHeader(algorithm string, typeHeader, kid *string) (string, error) {
176	header := &spb.Struct{
177		Fields: map[string]*spb.Value{
178			"alg": spb.NewStringValue(algorithm),
179		},
180	}
181	if typeHeader != nil {
182		header.Fields["typ"] = spb.NewStringValue(*typeHeader)
183	}
184	if kid != nil {
185		header.Fields["kid"] = spb.NewStringValue(*kid)
186	}
187	jsonHeader, err := header.MarshalJSON()
188	if err != nil {
189		return "", err
190	}
191	return base64Encode(jsonHeader), nil
192}
193
194func validateHeader(header *spb.Struct, algorithm string, tinkKID, customKID *string) error {
195	fields := header.GetFields()
196	if fields == nil {
197		return fmt.Errorf("header contains no fields")
198	}
199	alg, err := headerStringField(fields, "alg")
200	if err != nil {
201		return err
202	}
203	if alg != algorithm {
204		return fmt.Errorf("invalid alg")
205	}
206	if _, ok := fields["crit"]; ok {
207		return fmt.Errorf("all tokens with crit headers are rejected")
208	}
209	if tinkKID != nil && customKID != nil {
210		return fmt.Errorf("custom_kid can only be set for RAW keys")
211	}
212	_, hasKID := fields["kid"]
213	if tinkKID != nil && !hasKID {
214		return fmt.Errorf("missing kid in header")
215	}
216	if tinkKID != nil {
217		return validateKIDInHeader(fields, tinkKID)
218	}
219	if hasKID && customKID != nil {
220		return validateKIDInHeader(fields, customKID)
221	}
222	return nil
223}
224
225func validateKIDInHeader(fields map[string]*spb.Value, kid *string) error {
226	headerKID, err := headerStringField(fields, "kid")
227	if err != nil {
228		return err
229	}
230	if headerKID != *kid {
231		return fmt.Errorf("invalid kid header")
232	}
233	return nil
234}
235
236func headerStringField(fields map[string]*spb.Value, name string) (string, error) {
237	val, ok := fields[name]
238	if !ok {
239		return "", fmt.Errorf("header is missing %q", name)
240	}
241	str, ok := val.Kind.(*spb.Value_StringValue)
242	if !ok {
243		return "", fmt.Errorf("%q header isn't a string", name)
244	}
245	return str.StringValue, nil
246}
247