xref: /aosp_15_r20/external/tink/testing/go/jwt_service.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15///////////////////////////////////////////////////////////////////////////////
16
17package services
18
19import (
20	"bytes"
21	"context"
22	"fmt"
23	"time"
24
25	spb "google.golang.org/protobuf/types/known/structpb"
26	tpb "google.golang.org/protobuf/types/known/timestamppb"
27	wpb "google.golang.org/protobuf/types/known/wrapperspb"
28	"github.com/google/tink/go/insecurecleartextkeyset"
29	"github.com/google/tink/go/jwt"
30	"github.com/google/tink/go/keyset"
31	pb "github.com/google/tink/testing/go/protos/testing_api_go_grpc"
32)
33
34// JWTService implements the JWT testing service.
35type JWTService struct {
36	pb.JwtServer
37}
38
39func (s *JWTService) CreateJwtMac(ctx context.Context, req *pb.CreationRequest) (*pb.CreationResponse, error) {
40	handle, err := toKeysetHandle(req.GetAnnotatedKeyset())
41	if err != nil {
42		return &pb.CreationResponse{Err: err.Error()}, nil
43	}
44	_, err = jwt.NewMAC(handle)
45	if err != nil {
46		return &pb.CreationResponse{Err: err.Error()}, nil
47	}
48	return &pb.CreationResponse{}, nil
49}
50
51func (s *JWTService) CreateJwtPublicKeySign(ctx context.Context, req *pb.CreationRequest) (*pb.CreationResponse, error) {
52	handle, err := toKeysetHandle(req.GetAnnotatedKeyset())
53	if err != nil {
54		return &pb.CreationResponse{Err: err.Error()}, nil
55	}
56	_, err = jwt.NewSigner(handle)
57	if err != nil {
58		return &pb.CreationResponse{Err: err.Error()}, nil
59	}
60	return &pb.CreationResponse{}, nil
61}
62
63func (s *JWTService) CreateJwtPublicKeyVerify(ctx context.Context, req *pb.CreationRequest) (*pb.CreationResponse, error) {
64	handle, err := toKeysetHandle(req.GetAnnotatedKeyset())
65	if err != nil {
66		return &pb.CreationResponse{Err: err.Error()}, nil
67	}
68	_, err = jwt.NewVerifier(handle)
69	if err != nil {
70		return &pb.CreationResponse{Err: err.Error()}, nil
71	}
72	return &pb.CreationResponse{}, nil
73}
74
75func refString(s *wpb.StringValue) *string {
76	if s == nil {
77		return nil
78	}
79	v := s.GetValue()
80	return &v
81}
82
83func refTime(t *tpb.Timestamp) *time.Time {
84	if t == nil {
85		return nil
86	}
87	v := time.Unix(t.GetSeconds(), 0)
88	return &v
89}
90
91func arrayClaimToJSONString(array []interface{}) (string, error) {
92	lv, err := spb.NewList(array)
93	if err != nil {
94		return "", err
95	}
96	b, err := lv.MarshalJSON()
97	if err != nil {
98		return "", err
99	}
100	return string(b), nil
101}
102
103func jsonStringToArrayClaim(stringArray string) ([]interface{}, error) {
104	s := spb.NewListValue(&spb.ListValue{})
105	if err := s.UnmarshalJSON([]byte(stringArray)); err != nil {
106		return nil, err
107	}
108	if s.GetListValue() == nil {
109		return nil, fmt.Errorf("invalid list")
110	}
111	return s.GetListValue().AsSlice(), nil
112}
113
114func objectClaimToJSONString(o map[string]interface{}) (string, error) {
115	s, err := spb.NewStruct(o)
116	if err != nil {
117		return "", err
118	}
119	b, err := s.MarshalJSON()
120	if err != nil {
121		return "", err
122	}
123	return string(b), nil
124}
125
126func jsonStringToObjectClaim(obj string) (map[string]interface{}, error) {
127	s := &spb.Struct{}
128	if err := s.UnmarshalJSON([]byte(obj)); err != nil {
129		return nil, err
130	}
131	return s.AsMap(), nil
132}
133
134func customClaimsFromProto(cc map[string]*pb.JwtClaimValue) (map[string]interface{}, error) {
135	r := map[string]interface{}{}
136	for key, val := range cc {
137		switch val.Kind.(type) {
138		case *pb.JwtClaimValue_NullValue:
139			r[key] = nil
140		case *pb.JwtClaimValue_StringValue:
141			r[key] = val.GetStringValue()
142		case *pb.JwtClaimValue_NumberValue:
143			r[key] = val.GetNumberValue()
144		case *pb.JwtClaimValue_BoolValue:
145			r[key] = val.GetBoolValue()
146		case *pb.JwtClaimValue_JsonArrayValue:
147			a, err := jsonStringToArrayClaim(val.GetJsonArrayValue())
148			if err != nil {
149				return nil, err
150			}
151			r[key] = a
152		case *pb.JwtClaimValue_JsonObjectValue:
153			o, err := jsonStringToObjectClaim(val.GetJsonObjectValue())
154			if err != nil {
155				return nil, err
156			}
157			r[key] = o
158		default:
159			return nil, fmt.Errorf("unsupported type")
160		}
161	}
162	return r, nil
163}
164
165func tokenFromProto(t *pb.JwtToken) (*jwt.RawJWT, error) {
166	if t == nil {
167		return nil, nil
168	}
169	ccs, err := customClaimsFromProto(t.GetCustomClaims())
170	if err != nil {
171		return nil, err
172	}
173	opts := &jwt.RawJWTOptions{
174		TypeHeader:   refString(t.GetTypeHeader()),
175		Audiences:    t.GetAudiences(),
176		Subject:      refString(t.GetSubject()),
177		Issuer:       refString(t.GetIssuer()),
178		JWTID:        refString(t.GetJwtId()),
179		IssuedAt:     refTime(t.GetIssuedAt()),
180		NotBefore:    refTime(t.GetNotBefore()),
181		ExpiresAt:    refTime(t.GetExpiration()),
182		CustomClaims: ccs,
183	}
184	if opts.ExpiresAt == nil {
185		opts.WithoutExpiration = true
186	}
187	return jwt.NewRawJWT(opts)
188}
189
190func toStringValue(present bool, getValue func() (string, error), val **wpb.StringValue) error {
191	if !present {
192		return nil
193	}
194	v, err := getValue()
195	if err != nil {
196		return err
197	}
198	*val = &wpb.StringValue{Value: v}
199	return nil
200}
201
202func toTimeValue(present bool, getValue func() (time.Time, error), val **tpb.Timestamp) error {
203	if !present {
204		return nil
205	}
206	v, err := getValue()
207	if err != nil {
208		return err
209	}
210	*val = &tpb.Timestamp{Seconds: v.Unix()}
211	return nil
212}
213
214func tokenToProto(v *jwt.VerifiedJWT) (*pb.JwtToken, error) {
215	t := &pb.JwtToken{
216		CustomClaims: map[string]*pb.JwtClaimValue{},
217	}
218	if err := toStringValue(v.HasTypeHeader(), v.TypeHeader, &t.TypeHeader); err != nil {
219		return nil, err
220	}
221	if err := toStringValue(v.HasIssuer(), v.Issuer, &t.Issuer); err != nil {
222		return nil, err
223	}
224	if err := toStringValue(v.HasSubject(), v.Subject, &t.Subject); err != nil {
225		return nil, err
226	}
227	if err := toStringValue(v.HasJWTID(), v.JWTID, &t.JwtId); err != nil {
228		return nil, err
229	}
230	if err := toTimeValue(v.HasExpiration(), v.ExpiresAt, &t.Expiration); err != nil {
231		return nil, err
232	}
233	if err := toTimeValue(v.HasIssuedAt(), v.IssuedAt, &t.IssuedAt); err != nil {
234		return nil, err
235	}
236	if err := toTimeValue(v.HasNotBefore(), v.NotBefore, &t.NotBefore); err != nil {
237		return nil, err
238	}
239	if v.HasAudiences() {
240		aud, err := v.Audiences()
241		if err != nil {
242			return nil, err
243		}
244		t.Audiences = aud
245	}
246
247	for _, name := range v.CustomClaimNames() {
248		if v.HasArrayClaim(name) {
249			array, err := v.ArrayClaim(name)
250			if err != nil {
251				return nil, err
252			}
253			s, err := arrayClaimToJSONString(array)
254			if err != nil {
255				return nil, err
256			}
257			t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_JsonArrayValue{JsonArrayValue: s}}
258			continue
259		}
260		if v.HasObjectClaim(name) {
261			m, err := v.ObjectClaim(name)
262			if err != nil {
263				return nil, err
264			}
265			o, err := objectClaimToJSONString(m)
266			if err != nil {
267				return nil, err
268			}
269			t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_JsonObjectValue{JsonObjectValue: o}}
270			continue
271		}
272		if v.HasNullClaim(name) {
273			t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_NullValue{}}
274			continue
275		}
276		if v.HasStringClaim(name) {
277			s, err := v.StringClaim(name)
278			if err != nil {
279				return nil, err
280			}
281			t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_StringValue{StringValue: s}}
282			continue
283		}
284		if v.HasBooleanClaim(name) {
285			b, err := v.BooleanClaim(name)
286			if err != nil {
287				return nil, err
288			}
289			t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_BoolValue{BoolValue: b}}
290			continue
291		}
292		if v.HasNumberClaim(name) {
293			n, err := v.NumberClaim(name)
294			if err != nil {
295				return nil, err
296			}
297			t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_NumberValue{NumberValue: n}}
298			continue
299		}
300		return nil, fmt.Errorf("claim %q of unsupported type", name)
301	}
302
303	return t, nil
304}
305
306func validatorFromProto(v *pb.JwtValidator) (*jwt.Validator, error) {
307	fixedNow := time.Now()
308	if v.GetNow() != nil {
309		fixedNow = *refTime(v.GetNow())
310	}
311	opts := &jwt.ValidatorOpts{
312		ExpectedTypeHeader:     refString(v.GetExpectedTypeHeader()),
313		ExpectedAudience:       refString(v.GetExpectedAudience()),
314		ExpectedIssuer:         refString(v.GetExpectedIssuer()),
315		ExpectIssuedInThePast:  v.GetExpectIssuedInThePast(),
316		AllowMissingExpiration: v.GetAllowMissingExpiration(),
317		IgnoreTypeHeader:       v.GetIgnoreTypeHeader(),
318		IgnoreAudiences:        v.GetIgnoreAudience(),
319		IgnoreIssuer:           v.GetIgnoreIssuer(),
320		FixedNow:               fixedNow,
321		ClockSkew:              time.Duration(v.GetClockSkew().GetSeconds()) * time.Second,
322	}
323	return jwt.NewValidator(opts)
324}
325
326func jwtSignResponseError(err error) *pb.JwtSignResponse {
327	return &pb.JwtSignResponse{
328		Result: &pb.JwtSignResponse_Err{err.Error()}}
329}
330
331func jwtVerifyResponseError(err error) *pb.JwtVerifyResponse {
332	return &pb.JwtVerifyResponse{
333		Result: &pb.JwtVerifyResponse_Err{err.Error()}}
334}
335
336func jwtToJWKSetResponseError(err error) *pb.JwtToJwkSetResponse {
337	return &pb.JwtToJwkSetResponse{
338		Result: &pb.JwtToJwkSetResponse_Err{err.Error()}}
339}
340
341func jwtFromJwkSetResponseError(err error) *pb.JwtFromJwkSetResponse {
342	return &pb.JwtFromJwkSetResponse{
343		Result: &pb.JwtFromJwkSetResponse_Err{err.Error()}}
344}
345
346func (s *JWTService) ComputeMacAndEncode(ctx context.Context, req *pb.JwtSignRequest) (*pb.JwtSignResponse, error) {
347	handle, err := toKeysetHandle(req.GetAnnotatedKeyset())
348	if err != nil {
349		return jwtSignResponseError(err), nil
350	}
351	primitive, err := jwt.NewMAC(handle)
352	if err != nil {
353		return jwtSignResponseError(err), nil
354	}
355	rawJWT, err := tokenFromProto(req.GetRawJwt())
356	if err != nil {
357		return jwtSignResponseError(err), nil
358	}
359	compact, err := primitive.ComputeMACAndEncode(rawJWT)
360	if err != nil {
361		return jwtSignResponseError(err), nil
362	}
363	return &pb.JwtSignResponse{
364		Result: &pb.JwtSignResponse_SignedCompactJwt{compact},
365	}, nil
366}
367
368func (s *JWTService) VerifyMacAndDecode(ctx context.Context, req *pb.JwtVerifyRequest) (*pb.JwtVerifyResponse, error) {
369	handle, err := toKeysetHandle(req.GetAnnotatedKeyset())
370	if err != nil {
371		return jwtVerifyResponseError(err), nil
372	}
373	primitive, err := jwt.NewMAC(handle)
374	if err != nil {
375		return jwtVerifyResponseError(err), nil
376	}
377	validator, err := validatorFromProto(req.GetValidator())
378	if err != nil {
379		return jwtVerifyResponseError(err), nil
380	}
381	verified, err := primitive.VerifyMACAndDecode(req.GetSignedCompactJwt(), validator)
382	if err != nil {
383		return jwtVerifyResponseError(err), nil
384	}
385	verifiedJWT, err := tokenToProto(verified)
386	if err != nil {
387		return jwtVerifyResponseError(err), nil
388	}
389	return &pb.JwtVerifyResponse{
390		Result: &pb.JwtVerifyResponse_VerifiedJwt{verifiedJWT},
391	}, nil
392}
393
394func (s *JWTService) PublicKeySignAndEncode(ctx context.Context, req *pb.JwtSignRequest) (*pb.JwtSignResponse, error) {
395	handle, err := toKeysetHandle(req.GetAnnotatedKeyset())
396	if err != nil {
397		return jwtSignResponseError(err), nil
398	}
399	signer, err := jwt.NewSigner(handle)
400	if err != nil {
401		return jwtSignResponseError(err), nil
402	}
403	rawJWT, err := tokenFromProto(req.GetRawJwt())
404	if err != nil {
405		return jwtSignResponseError(err), nil
406	}
407	compact, err := signer.SignAndEncode(rawJWT)
408	if err != nil {
409		return jwtSignResponseError(err), nil
410	}
411	return &pb.JwtSignResponse{
412		Result: &pb.JwtSignResponse_SignedCompactJwt{compact},
413	}, nil
414}
415
416func (s *JWTService) PublicKeyVerifyAndDecode(ctx context.Context, req *pb.JwtVerifyRequest) (*pb.JwtVerifyResponse, error) {
417
418	handle, err := toKeysetHandle(req.GetAnnotatedKeyset())
419	if err != nil {
420		return jwtVerifyResponseError(err), nil
421	}
422	verifier, err := jwt.NewVerifier(handle)
423	if err != nil {
424		return jwtVerifyResponseError(err), nil
425	}
426	validator, err := validatorFromProto(req.GetValidator())
427	if err != nil {
428		return jwtVerifyResponseError(err), nil
429	}
430	verified, err := verifier.VerifyAndDecode(req.GetSignedCompactJwt(), validator)
431	if err != nil {
432		return jwtVerifyResponseError(err), nil
433	}
434	verifiedJWT, err := tokenToProto(verified)
435	if err != nil {
436		return jwtVerifyResponseError(err), nil
437	}
438	return &pb.JwtVerifyResponse{
439		Result: &pb.JwtVerifyResponse_VerifiedJwt{verifiedJWT},
440	}, nil
441}
442
443func (s *JWTService) ToJwkSet(ctx context.Context, req *pb.JwtToJwkSetRequest) (*pb.JwtToJwkSetResponse, error) {
444	ks, err := keyset.NewBinaryReader(bytes.NewReader(req.GetKeyset())).Read()
445	if err != nil {
446		return jwtToJWKSetResponseError(err), nil
447	}
448	handle, err := keyset.NewHandleWithNoSecrets(ks)
449	if err != nil {
450		return jwtToJWKSetResponseError(err), nil
451	}
452	jwkSet, err := jwt.JWKSetFromPublicKeysetHandle(handle)
453	if err != nil {
454		return jwtToJWKSetResponseError(err), nil
455	}
456	return &pb.JwtToJwkSetResponse{
457		Result: &pb.JwtToJwkSetResponse_JwkSet{string(jwkSet)},
458	}, nil
459}
460
461func (s *JWTService) FromJwkSet(ctx context.Context, req *pb.JwtFromJwkSetRequest) (*pb.JwtFromJwkSetResponse, error) {
462	handle, err := jwt.JWKSetToPublicKeysetHandle([]byte(req.GetJwkSet()))
463	if err != nil {
464		return jwtFromJwkSetResponseError(err), nil
465	}
466	b := &bytes.Buffer{}
467	if err := insecurecleartextkeyset.Write(handle, keyset.NewBinaryWriter(b)); err != nil {
468		return jwtFromJwkSetResponseError(err), nil
469	}
470	return &pb.JwtFromJwkSetResponse{
471		Result: &pb.JwtFromJwkSetResponse_Keyset{b.Bytes()},
472	}, nil
473}
474