xref: /aosp_15_r20/external/tink/testing/go/jwt_service_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15///////////////////////////////////////////////////////////////////////////////
16
17package services_test
18
19import (
20	"context"
21	"errors"
22	"fmt"
23	"testing"
24
25	dpb "google.golang.org/protobuf/types/known/durationpb"
26	spb "google.golang.org/protobuf/types/known/structpb"
27	tpb "google.golang.org/protobuf/types/known/timestamppb"
28	wpb "google.golang.org/protobuf/types/known/wrapperspb"
29	"github.com/google/go-cmp/cmp"
30	"google.golang.org/protobuf/proto"
31	"google.golang.org/protobuf/testing/protocmp"
32	"github.com/google/tink/go/aead"
33	"github.com/google/tink/go/jwt"
34	"github.com/google/tink/go/signature"
35	"github.com/google/tink/testing/go/services"
36	pb "github.com/google/tink/testing/go/protos/testing_api_go_grpc"
37)
38
39func verifiedJWTFromResponse(response *pb.JwtVerifyResponse) (*pb.JwtToken, error) {
40	switch r := response.Result.(type) {
41	case *pb.JwtVerifyResponse_VerifiedJwt:
42		return r.VerifiedJwt, nil
43	case *pb.JwtVerifyResponse_Err:
44		return nil, errors.New(r.Err)
45	default:
46		return nil, fmt.Errorf("response.Result has unexpected type %T", r)
47	}
48}
49
50func signedCompactJWTFromResponse(response *pb.JwtSignResponse) (string, error) {
51	switch r := response.Result.(type) {
52	case *pb.JwtSignResponse_SignedCompactJwt:
53		return r.SignedCompactJwt, nil
54	case *pb.JwtSignResponse_Err:
55		return "", errors.New(r.Err)
56	default:
57		return "", fmt.Errorf("response.Result has unexpected type %T", r)
58	}
59}
60
61func jwkSetFromResponse(response *pb.JwtToJwkSetResponse) (string, error) {
62	switch r := response.Result.(type) {
63	case *pb.JwtToJwkSetResponse_JwkSet:
64		return r.JwkSet, nil
65	case *pb.JwtToJwkSetResponse_Err:
66		return "", errors.New(r.Err)
67	default:
68		return "", fmt.Errorf("response.Result has unexpected type %T", r)
69	}
70}
71
72func keysetFromResponse(response *pb.JwtFromJwkSetResponse) ([]byte, error) {
73	switch r := response.Result.(type) {
74	case *pb.JwtFromJwkSetResponse_Keyset:
75		return r.Keyset, nil
76	case *pb.JwtFromJwkSetResponse_Err:
77		return nil, errors.New(r.Err)
78	default:
79		return nil, fmt.Errorf("response.Result has unexpected type %T", r)
80	}
81}
82
83type jwtTestCase struct {
84	tag       string
85	rawJWT    *pb.JwtToken
86	validator *pb.JwtValidator
87}
88
89func TestJWTComputeInvalidJWT(t *testing.T) {
90	for _, tc := range []jwtTestCase{
91		{
92			tag:    "nil rawJWT",
93			rawJWT: nil,
94		},
95		{
96			tag: "invalid json array string",
97			rawJWT: &pb.JwtToken{
98				CustomClaims: map[string]*pb.JwtClaimValue{
99					"cc-array": &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_JsonArrayValue{JsonArrayValue: "{35}"}},
100				},
101			},
102		},
103		{
104			tag: "invalid json object string",
105			rawJWT: &pb.JwtToken{
106				CustomClaims: map[string]*pb.JwtClaimValue{
107					"cc-object": &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_JsonObjectValue{JsonObjectValue: `["o":"a"]`}},
108				},
109			},
110		},
111	} {
112		t.Run(tc.tag, func(t *testing.T) {
113			keysetService := &services.KeysetService{}
114			jwtService := &services.JWTService{}
115			ctx := context.Background()
116			template, err := proto.Marshal(jwt.HS256Template())
117			if err != nil {
118				t.Fatalf("proto.Marshal(jwt.HS256Template()) failed: %v", err)
119			}
120			keyset, err := genKeyset(ctx, keysetService, template)
121			if err != nil {
122				t.Fatalf("genKeyset failed: %v", err)
123			}
124			signResponse, err := jwtService.ComputeMacAndEncode(ctx, &pb.JwtSignRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: keyset}, RawJwt: tc.rawJWT})
125			if err != nil {
126				t.Fatalf("jwtService.ComputeMacAndEncode() err = %v, want nil", err)
127			}
128			if _, err := signedCompactJWTFromResponse(signResponse); err == nil {
129				t.Fatalf("JwtSignResponse: error = nil, want error")
130			}
131		})
132	}
133}
134
135func TestSuccessfulJwtMacCreation(t *testing.T) {
136	keysetService := &services.KeysetService{}
137	jwtService := &services.JWTService{}
138	ctx := context.Background()
139
140	template, err := proto.Marshal(jwt.HS256Template())
141	if err != nil {
142		t.Fatalf("proto.Marshal(jwt.HS256Template()) failed: %v, want nil", err)
143	}
144
145	keyset, err := genKeyset(ctx, keysetService, template)
146	if err != nil {
147		t.Fatalf("genKeyset failed: %v", err)
148	}
149
150	result, err := jwtService.CreateJwtMac(ctx, &pb.CreationRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: keyset}})
151	if err != nil {
152		t.Fatalf("CreateJwtMac with good keyset failed with gRPC error: %v, want nil", err)
153	}
154	if result.GetErr() != "" {
155		t.Fatalf("CreateJwtMac with good keyset failed with result.GetErr() = %q, want empty string", result.GetErr())
156	}
157}
158
159func TestFailingJwtMacCreation(t *testing.T) {
160	keysetService := &services.KeysetService{}
161	jwtService := &services.JWTService{}
162	ctx := context.Background()
163
164	// We use signature keys -- then we cannot create a JwtMac
165	template, err := proto.Marshal(aead.AES128GCMKeyTemplate())
166	if err != nil {
167		t.Fatalf("proto.Marshal(signature.ECDSAP256KeyTemplate()) failed: %v", err)
168	}
169
170	badKeyset, err := genKeyset(ctx, keysetService, template)
171	if err != nil {
172		t.Fatalf("genKeyset failed: %v", err)
173	}
174
175	result, err := jwtService.CreateJwtMac(ctx, &pb.CreationRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: badKeyset}})
176	if err != nil {
177		t.Fatalf("CreateJwtMac with bad keyset failed with gRPC error: %v", err)
178	}
179	if result.GetErr() == "" {
180		t.Fatalf("result.GetErr() of bad keyset after CreateJwtMac is empty, want not empty")
181	}
182}
183
184func TestJWTComputeMACWithInvalidKeysetFails(t *testing.T) {
185	keysetService := &services.KeysetService{}
186	jwtService := &services.JWTService{}
187	ctx := context.Background()
188	template, err := proto.Marshal(aead.AES256GCMKeyTemplate())
189	if err != nil {
190		t.Fatalf("proto.Marshal(jwt.AES256GCMKeyTemplate()) failed: %v", err)
191	}
192	keyset, err := genKeyset(ctx, keysetService, template)
193	if err != nil {
194		t.Fatalf("genKeyset failed: %v", err)
195	}
196	rawJWT := &pb.JwtToken{
197		TypeHeader: &wpb.StringValue{Value: "JWT"},
198		Issuer:     &wpb.StringValue{Value: "issuer"},
199	}
200	signResponse, err := jwtService.ComputeMacAndEncode(ctx, &pb.JwtSignRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: keyset}, RawJwt: rawJWT})
201	if err != nil {
202		t.Fatalf("jwtService.ComputeMacAndEncode() err = %v, want nil", err)
203	}
204	if _, err := signedCompactJWTFromResponse(signResponse); err == nil {
205		t.Fatalf("JwtSignResponse: error = nil, want error")
206	}
207}
208
209func TestJWTComputeAndVerifyMac(t *testing.T) {
210	for _, tc := range []jwtTestCase{
211		{
212			tag: "all claims and custom claims",
213			rawJWT: &pb.JwtToken{
214				TypeHeader: &wpb.StringValue{Value: "JWT"},
215				Issuer:     &wpb.StringValue{Value: "issuer"},
216				Subject:    &wpb.StringValue{Value: "subject"},
217				JwtId:      &wpb.StringValue{Value: "tink"},
218				Audiences:  []string{"audience"},
219				Expiration: &tpb.Timestamp{Seconds: 123456},
220				NotBefore:  &tpb.Timestamp{Seconds: 12345},
221				IssuedAt:   &tpb.Timestamp{Seconds: 1234},
222				CustomClaims: map[string]*pb.JwtClaimValue{
223					"cc-null":   &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_NullValue{}},
224					"cc-num":    &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_NumberValue{NumberValue: 5.67}},
225					"cc-bool":   &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_BoolValue{BoolValue: true}},
226					"cc-string": &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_StringValue{StringValue: "foo bar"}},
227					"cc-array":  &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_JsonArrayValue{JsonArrayValue: "[35]"}},
228					"cc-object": &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_JsonObjectValue{JsonObjectValue: `{"key":"val"}`}},
229				},
230			},
231			validator: &pb.JwtValidator{
232				ExpectedTypeHeader: &wpb.StringValue{Value: "JWT"},
233				ExpectedIssuer:     &wpb.StringValue{Value: "issuer"},
234				ExpectedAudience:   &wpb.StringValue{Value: "audience"},
235				Now:                &tpb.Timestamp{Seconds: 12345},
236				ClockSkew:          &dpb.Duration{Seconds: 0},
237			},
238		},
239		{
240			tag: "without custom claims",
241			rawJWT: &pb.JwtToken{
242				TypeHeader: &wpb.StringValue{Value: "JWT"},
243				Issuer:     &wpb.StringValue{Value: "issuer"},
244				Subject:    &wpb.StringValue{Value: "subject"},
245				Audiences:  []string{"audience"},
246			},
247			validator: &pb.JwtValidator{
248				ExpectedTypeHeader:     &wpb.StringValue{Value: "JWT"},
249				ExpectedIssuer:         &wpb.StringValue{Value: "issuer"},
250				ExpectedAudience:       &wpb.StringValue{Value: "audience"},
251				AllowMissingExpiration: true,
252			},
253		},
254		{
255			tag: "without expiration",
256			rawJWT: &pb.JwtToken{
257				Subject: &wpb.StringValue{Value: "subject"},
258			},
259			validator: &pb.JwtValidator{
260				AllowMissingExpiration: true,
261			},
262		},
263		{
264			tag: "clock skew",
265			rawJWT: &pb.JwtToken{
266				Expiration: &tpb.Timestamp{Seconds: 1234},
267			},
268			validator: &pb.JwtValidator{
269				Now:       &tpb.Timestamp{Seconds: 1235},
270				ClockSkew: &dpb.Duration{Seconds: 2},
271			},
272		},
273	} {
274		t.Run(tc.tag, func(t *testing.T) {
275			keysetService := &services.KeysetService{}
276			jwtService := &services.JWTService{}
277			ctx := context.Background()
278			template, err := proto.Marshal(jwt.HS256Template())
279			if err != nil {
280				t.Fatalf("proto.Marshal(jwt.HS256Template()) failed: %v", err)
281			}
282			keyset, err := genKeyset(ctx, keysetService, template)
283			if err != nil {
284				t.Fatalf("genKeyset failed: %v", err)
285			}
286
287			signResponse, err := jwtService.ComputeMacAndEncode(ctx, &pb.JwtSignRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: keyset}, RawJwt: tc.rawJWT})
288			if err != nil {
289				t.Fatalf("jwtService.ComputeMacAndEncode() err = %v, want nil", err)
290			}
291			compact, err := signedCompactJWTFromResponse(signResponse)
292			if err != nil {
293				t.Fatalf("JwtSignResponse_Err: %v", err)
294			}
295			verifyResponse, err := jwtService.VerifyMacAndDecode(ctx, &pb.JwtVerifyRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: keyset}, SignedCompactJwt: compact, Validator: tc.validator})
296			if err != nil {
297				t.Fatalf("jwtService.VerifyMacAndDecode() err = %v, want nil", err)
298			}
299			verifiedJWT, err := verifiedJWTFromResponse(verifyResponse)
300			if err != nil {
301				t.Fatalf("JwtVerifyResponse_Err: %v", err)
302			}
303			if !cmp.Equal(verifiedJWT, tc.rawJWT, protocmp.Transform()) {
304				t.Errorf("verifiedJWT doesn't match expected value: (+ got, - want) %v", cmp.Diff(verifiedJWT, tc.rawJWT, protocmp.Transform()))
305			}
306		})
307	}
308}
309
310func TestJWTVerifyMACFailures(t *testing.T) {
311	keysetService := &services.KeysetService{}
312	jwtService := &services.JWTService{}
313	ctx := context.Background()
314	template, err := proto.Marshal(jwt.HS256Template())
315	if err != nil {
316		t.Fatalf("proto.Marshal(jwt.HS256Template()) failed: %v", err)
317	}
318	keyset, err := genKeyset(ctx, keysetService, template)
319	if err != nil {
320		t.Fatalf("genKeyset failed: %v", err)
321	}
322	rawJWT := &pb.JwtToken{
323		TypeHeader: &wpb.StringValue{Value: "JWT"},
324		Expiration: &tpb.Timestamp{Seconds: 123456},
325		NotBefore:  &tpb.Timestamp{Seconds: 12345},
326		IssuedAt:   &tpb.Timestamp{Seconds: 1234},
327	}
328	signResponse, err := jwtService.ComputeMacAndEncode(ctx, &pb.JwtSignRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: keyset}, RawJwt: rawJWT})
329	if err != nil {
330		t.Fatalf("jwtService.ComputeMacAndEncode() err = %v, want nil", err)
331	}
332	compact, err := signedCompactJWTFromResponse(signResponse)
333	if err != nil {
334		t.Fatalf("JwtSignResponse_Err: %v", err)
335	}
336	validator := &pb.JwtValidator{
337		ExpectedTypeHeader: &wpb.StringValue{Value: "JWT"},
338		Now:                &tpb.Timestamp{Seconds: 12345},
339	}
340	verifyResponse, err := jwtService.VerifyMacAndDecode(ctx, &pb.JwtVerifyRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: keyset}, SignedCompactJwt: compact, Validator: validator})
341	if err != nil {
342		t.Fatalf("jwtService.VerifyMacAndDecode() err = %v, want nil", err)
343	}
344	if _, err := verifiedJWTFromResponse(verifyResponse); err != nil {
345		t.Fatalf("JwtVerifyResponse_Err: %v", err)
346	}
347	for _, tc := range []jwtTestCase{
348		{
349			tag: "unexpected type header",
350			validator: &pb.JwtValidator{
351				ExpectedTypeHeader: &wpb.StringValue{Value: "unexpected"},
352				Now:                &tpb.Timestamp{Seconds: 12345},
353			},
354		},
355		{
356			tag: "expired token",
357			validator: &pb.JwtValidator{
358				ExpectedTypeHeader: &wpb.StringValue{Value: "JWT"},
359				Now:                &tpb.Timestamp{Seconds: 999999999999},
360			},
361		},
362		{
363			tag: "expect issued in the past",
364			validator: &pb.JwtValidator{
365				ExpectedTypeHeader:    &wpb.StringValue{Value: "JWT"},
366				Now:                   &tpb.Timestamp{Seconds: 1233},
367				ExpectIssuedInThePast: true,
368			},
369		},
370	} {
371		t.Run(tc.tag, func(t *testing.T) {
372			verifyResponse, err := jwtService.VerifyMacAndDecode(ctx, &pb.JwtVerifyRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: keyset}, SignedCompactJwt: compact, Validator: tc.validator})
373			if err != nil {
374				t.Fatalf("jwtService.VerifyMacAndDecode() err = %v, want nil", err)
375			}
376			if _, err := verifiedJWTFromResponse(verifyResponse); err == nil {
377				t.Fatalf("JwtVerifyResponse_Err: nil, want error")
378			}
379		})
380	}
381}
382
383func TestSuccessfulJwtSignVerifyCreation(t *testing.T) {
384	keysetService := &services.KeysetService{}
385	jwtService := &services.JWTService{}
386	ctx := context.Background()
387
388	template, err := proto.Marshal(jwt.ES256Template())
389	if err != nil {
390		t.Fatalf("proto.Marshal(hybrid.ES256Template()) failed: %v", err)
391	}
392
393	privateKeyset, err := genKeyset(ctx, keysetService, template)
394	if err != nil {
395		t.Fatalf("genKeyset failed: %v", err)
396	}
397
398	result, err := jwtService.CreateJwtPublicKeySign(ctx, &pb.CreationRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: privateKeyset}})
399	if err != nil {
400		t.Fatalf("CreateJwtPublicKeySign with good keyset failed with gRPC error: %v, want nil", err)
401	}
402	if result.GetErr() != "" {
403		t.Fatalf("CreateJwtPublicKeySign with good keyset failed with result.GetErr() = %q, want empty string", result.GetErr())
404	}
405}
406
407func TestSuccessfulJwtVerifyCreation(t *testing.T) {
408	keysetService := &services.KeysetService{}
409	jwtService := &services.JWTService{}
410	ctx := context.Background()
411
412	template, err := proto.Marshal(jwt.ES256Template())
413	if err != nil {
414		t.Fatalf("proto.Marshal(hybrid.ES256Template()) failed: %v", err)
415	}
416
417	privateKeyset, err := genKeyset(ctx, keysetService, template)
418	if err != nil {
419		t.Fatalf("genKeyset failed: %v", err)
420	}
421	publicKeyset, err := pubKeyset(ctx, keysetService, privateKeyset)
422	if err != nil {
423		t.Fatalf("pubKeyset failed: %v", err)
424	}
425
426	result, err := jwtService.CreateJwtPublicKeyVerify(ctx, &pb.CreationRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: publicKeyset}})
427	if err != nil {
428		t.Fatalf("CreateJwtPublicKeyVerify with good keyset failed with gRPC error: %v", err)
429	}
430	if result.GetErr() != "" {
431		t.Fatalf("CreateJwtPublicKeyVerify with good keyset failed with result.GetErr() = %q, want empty string", result.GetErr())
432	}
433}
434
435func TestFailingJwtSignCreation(t *testing.T) {
436	keysetService := &services.KeysetService{}
437	jwtService := &services.JWTService{}
438	ctx := context.Background()
439
440	// We use signature keys -- then we cannot create a hybrid encrypt
441	template, err := proto.Marshal(signature.ECDSAP256KeyTemplate())
442	if err != nil {
443		t.Fatalf("proto.Marshal(signature.ECDSAP256KeyTemplate()) failed: %v", err)
444	}
445
446	privateKeyset, err := genKeyset(ctx, keysetService, template)
447	if err != nil {
448		t.Fatalf("genKeyset failed: %v", err)
449	}
450
451	result, err := jwtService.CreateJwtPublicKeySign(ctx, &pb.CreationRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: privateKeyset}})
452	if err != nil {
453		t.Fatalf("CreateJwtPublicKeySign with bad keyset failed with gRPC error: %v", err)
454	}
455	if result.GetErr() == "" {
456		t.Fatalf("CreateJwtPublicKeySign with bad keyset succeeded")
457	}
458}
459
460func TestFailingJwtVerifyCreation(t *testing.T) {
461	keysetService := &services.KeysetService{}
462	jwtService := &services.JWTService{}
463	ctx := context.Background()
464
465	// We use signature keys -- then we cannot create a hybrid encrypt
466	template, err := proto.Marshal(signature.ECDSAP256KeyTemplate())
467	if err != nil {
468		t.Fatalf("proto.Marshal(signature.ECDSAP256KeyTemplate()) failed: %v", err)
469	}
470
471	privateKeyset, err := genKeyset(ctx, keysetService, template)
472	if err != nil {
473		t.Fatalf("genKeyset failed: %v", err)
474	}
475	publicKeyset, err := pubKeyset(ctx, keysetService, privateKeyset)
476	if err != nil {
477		t.Fatalf("pubKeyset failed: %v", err)
478	}
479
480	result, err := jwtService.CreateJwtPublicKeyVerify(ctx, &pb.CreationRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: publicKeyset}})
481	if err != nil {
482		t.Fatalf("CreateJwtPublicKeyVerify with good keyset failed with gRPC error: %v", err)
483	}
484	if result.GetErr() == "" {
485		t.Fatalf("CreateJwtPublicKeyVerify with bad keyset succeeded")
486	}
487}
488
489func TestJWTPublicKeySignWithInvalidKeysetFails(t *testing.T) {
490	keysetService := &services.KeysetService{}
491	jwtService := &services.JWTService{}
492
493	ctx := context.Background()
494	template, err := proto.Marshal(aead.AES256GCMKeyTemplate())
495	if err != nil {
496		t.Fatalf("proto.Marshal(aead.AES256GCMKeyTemplate()) failed: %v", err)
497	}
498	privateKeyset, err := genKeyset(ctx, keysetService, template)
499	if err != nil {
500		t.Fatalf("genKeyset failed: %v", err)
501	}
502	rawJWT := &pb.JwtToken{
503		Subject: &wpb.StringValue{Value: "tink-subject"},
504	}
505	signResponse, err := jwtService.PublicKeySignAndEncode(ctx, &pb.JwtSignRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: privateKeyset}, RawJwt: rawJWT})
506	if err != nil {
507		t.Fatalf("jwtService.PublicKeySignAndEncode() err = %v", err)
508	}
509	if _, err := signedCompactJWTFromResponse(signResponse); err == nil {
510		t.Fatalf("JwtSignResponse_Err: nil want error")
511	}
512}
513
514func TestJWTPublicKeySignInvalidTokenFails(t *testing.T) {
515	keysetService := &services.KeysetService{}
516	jwtService := &services.JWTService{}
517
518	ctx := context.Background()
519	template, err := proto.Marshal(jwt.ES256Template())
520	if err != nil {
521		t.Fatalf("proto.Marshal(jwt.ES256Template()) failed: %v", err)
522	}
523	privateKeyset, err := genKeyset(ctx, keysetService, template)
524	if err != nil {
525		t.Fatalf("genKeyset failed: %v", err)
526	}
527	for _, tc := range []jwtTestCase{
528		{
529			tag:    "nil rawJWT",
530			rawJWT: nil,
531		},
532		{
533			tag: "invalid json array string",
534			rawJWT: &pb.JwtToken{
535				CustomClaims: map[string]*pb.JwtClaimValue{
536					"cc-array": &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_JsonArrayValue{JsonArrayValue: "{35}"}},
537				},
538			},
539		},
540		{
541			tag: "invalid json object string",
542			rawJWT: &pb.JwtToken{
543				CustomClaims: map[string]*pb.JwtClaimValue{
544					"cc-object": &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_JsonObjectValue{JsonObjectValue: `["o":"a"]`}},
545				},
546			},
547		},
548	} {
549		t.Run(tc.tag, func(t *testing.T) {
550			signResponse, err := jwtService.PublicKeySignAndEncode(ctx, &pb.JwtSignRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: privateKeyset}, RawJwt: tc.rawJWT})
551			if err != nil {
552				t.Fatalf("jwtService.PublicKeySignAndEncode() err = %v", err)
553			}
554			if _, err := signedCompactJWTFromResponse(signResponse); err == nil {
555				t.Fatalf("JwtSignResponse_Err: nil want error")
556			}
557		})
558	}
559}
560
561func TestJWTPublicKeyVerifyFails(t *testing.T) {
562	keysetService := &services.KeysetService{}
563	jwtService := &services.JWTService{}
564
565	ctx := context.Background()
566	template, err := proto.Marshal(jwt.ES256Template())
567	if err != nil {
568		t.Fatalf("proto.Marshal(jwt.ES256Template()) failed: %v", err)
569	}
570	privateKeyset, err := genKeyset(ctx, keysetService, template)
571	if err != nil {
572		t.Fatalf("genKeyset failed: %v", err)
573	}
574	publicKeyset, err := pubKeyset(ctx, keysetService, privateKeyset)
575	if err != nil {
576		t.Fatalf("pubKeyset failed: %v", err)
577	}
578	rawJWT := &pb.JwtToken{
579		Subject: &wpb.StringValue{Value: "tink-subject"},
580	}
581	signResponse, err := jwtService.PublicKeySignAndEncode(ctx, &pb.JwtSignRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: privateKeyset}, RawJwt: rawJWT})
582	if err != nil {
583		t.Fatalf("jwtService.PublicKeySignAndEncode() err = %v", err)
584	}
585	compact, err := signedCompactJWTFromResponse(signResponse)
586	if err != nil {
587		t.Fatalf("JwtSignResponse_Err failed: %v", err)
588	}
589	validator := &pb.JwtValidator{
590		ExpectedTypeHeader: &wpb.StringValue{Value: "JWT"},
591	}
592	verifyResponse, err := jwtService.PublicKeyVerifyAndDecode(ctx, &pb.JwtVerifyRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: publicKeyset}, SignedCompactJwt: compact, Validator: validator})
593	if err != nil {
594		t.Fatalf("jwtVerifySignature failed: %v", err)
595	}
596	if _, err := verifiedJWTFromResponse(verifyResponse); err == nil {
597		t.Fatalf("JwtVerifyResponse_Err: nil want error")
598	}
599}
600
601func TestJWTPublicKeySignAndEncodeVerifyAndDecode(t *testing.T) {
602	keysetService := &services.KeysetService{}
603	jwtService := &services.JWTService{}
604
605	ctx := context.Background()
606	template, err := proto.Marshal(jwt.ES256Template())
607	if err != nil {
608		t.Fatalf("proto.Marshal(jwt.ES256Template()) failed: %v", err)
609	}
610	privateKeyset, err := genKeyset(ctx, keysetService, template)
611	if err != nil {
612		t.Fatalf("genKeyset failed: %v", err)
613	}
614	publicKeyset, err := pubKeyset(ctx, keysetService, privateKeyset)
615	if err != nil {
616		t.Fatalf("pubKeyset failed: %v", err)
617	}
618	rawJWT := &pb.JwtToken{
619		Subject: &wpb.StringValue{Value: "tink-subject"},
620	}
621	signResponse, err := jwtService.PublicKeySignAndEncode(ctx, &pb.JwtSignRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: privateKeyset}, RawJwt: rawJWT})
622	if err != nil {
623		t.Fatalf("jwtService.PublicKeySignAndEncode() err = %v", err)
624	}
625	compact, err := signedCompactJWTFromResponse(signResponse)
626	if err != nil {
627		t.Fatalf("JwtSignResponse_Err failed: %v", err)
628	}
629	validator := &pb.JwtValidator{
630		AllowMissingExpiration: true,
631	}
632	verifyResponse, err := jwtService.PublicKeyVerifyAndDecode(ctx, &pb.JwtVerifyRequest{AnnotatedKeyset: &pb.AnnotatedKeyset{SerializedKeyset: publicKeyset}, SignedCompactJwt: compact, Validator: validator})
633	if err != nil {
634		t.Fatalf("jwtVerifySignature failed: %v", err)
635	}
636	verifiedJWT, err := verifiedJWTFromResponse(verifyResponse)
637	if err != nil {
638		t.Fatalf("JwtVerifyResponse_Err: %v", err)
639	}
640	if !cmp.Equal(verifiedJWT, rawJWT, protocmp.Transform()) {
641		t.Errorf("verifiedJWT doesn't match expected value: (+ got, - want) %v", cmp.Diff(verifiedJWT, rawJWT, protocmp.Transform()))
642	}
643}
644
645func TestToJwkSetWithPrivateKeyFails(t *testing.T) {
646	keysetService := &services.KeysetService{}
647	jwtService := &services.JWTService{}
648
649	ctx := context.Background()
650	template, err := proto.Marshal(jwt.ES256Template())
651	if err != nil {
652		t.Fatalf("proto.Marshal(jwt.ES256Template()) failed: %v", err)
653	}
654	privateKeyset, err := genKeyset(ctx, keysetService, template)
655	if err != nil {
656		t.Fatalf("genKeyset failed: %v", err)
657	}
658	toJWKResponse, err := jwtService.ToJwkSet(ctx, &pb.JwtToJwkSetRequest{Keyset: privateKeyset})
659	if err != nil {
660		t.Fatalf("jwtService.ToJwkSet() err = %v, want nil", err)
661	}
662	if _, err := jwkSetFromResponse(toJWKResponse); err == nil {
663		t.Fatalf("JwtToJwkSetResponse_Err: = nil, want error")
664	}
665}
666
667func TestFromJwkSetPrivateKeyFails(t *testing.T) {
668	jwtService := &services.JWTService{}
669	ctx := context.Background()
670	jwkES256PublicKey := `{
671	  "keys":[{
672	  "kty":"EC",
673	  "crv":"P-256",
674	  "x":"wO6uIxh8SkKOO8VjZXNRTteRcwCPE4_4JElKyaa0fcQ",
675	  "y":"7oRiYhnmkP6nqrdXWgtsWUWq5uFRLJkhyVFiWPRB278",
676		"d":"8oRinhnmkYjkqrdXWgtsWUWq5uFRLJkhyVFiWPRB278",
677	  "use":"sig","alg":"ES256","key_ops":["verify"],
678	  "kid":"EhuduQ"}]
679	}`
680	fromJWKResponse, err := jwtService.FromJwkSet(ctx, &pb.JwtFromJwkSetRequest{JwkSet: jwkES256PublicKey})
681	if err != nil {
682		t.Fatalf("jwtService.FromJwkSet() err = %v, want nil", err)
683	}
684	if _, err := keysetFromResponse(fromJWKResponse); err == nil {
685		t.Fatalf("JwtFromJwkSetResponse_Err = nil, want error")
686	}
687}
688
689func TestFromJwkToJwkSet(t *testing.T) {
690	jwtService := &services.JWTService{}
691	ctx := context.Background()
692	jwkES256PublicKey := `{
693	  "keys":[{
694	  "kty":"EC",
695	  "crv":"P-256",
696	  "x":"wO6uIxh8SkKOO8VjZXNRTteRcwCPE4_4JElKyaa0fcQ",
697	  "y":"7oRiYhnmkP6nqrdXWgtsWUWq5uFRLJkhyVFiWPRB278",
698	  "use":"sig","alg":"ES256","key_ops":["verify"],
699	  "kid":"EhuduQ"}]
700	}`
701	fromJWKResponse, err := jwtService.FromJwkSet(ctx, &pb.JwtFromJwkSetRequest{JwkSet: jwkES256PublicKey})
702	if err != nil {
703		t.Fatalf("jwtService.FromJwkSet() err = %v, want nil", err)
704	}
705	ks, err := keysetFromResponse(fromJWKResponse)
706	if err != nil {
707		t.Fatalf("JwtFromJwkSetResponse_Err: = %v, want nil", err)
708	}
709	toJWKResponse, err := jwtService.ToJwkSet(ctx, &pb.JwtToJwkSetRequest{Keyset: ks})
710	if err != nil {
711		t.Fatalf("jwtService.ToJwkSet() err = %v, want nil", err)
712	}
713	jwkSet, err := jwkSetFromResponse(toJWKResponse)
714	if err != nil {
715		t.Fatalf("JwtToJwkSetResponse_Err: = %v, want nil", err)
716	}
717	got := &spb.Struct{}
718	if err := got.UnmarshalJSON([]byte(jwkSet)); err != nil {
719		t.Fatalf("got.UnmarshalJSON() err = %v, want nil", err)
720	}
721	want := &spb.Struct{}
722	if err := want.UnmarshalJSON([]byte(jwkES256PublicKey)); err != nil {
723		t.Fatalf("want.UnmarshalJSON() err = %v, want nil", err)
724	}
725	if !cmp.Equal(want, got, protocmp.Transform()) {
726		t.Errorf("mismatch in jwk sets: diff (-want,+got): %v", cmp.Diff(want, got, protocmp.Transform()))
727	}
728}
729