1// Copyright 2022 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15/////////////////////////////////////////////////////////////////////////////// 16 17package services 18 19import ( 20 "bytes" 21 "context" 22 "fmt" 23 "time" 24 25 spb "google.golang.org/protobuf/types/known/structpb" 26 tpb "google.golang.org/protobuf/types/known/timestamppb" 27 wpb "google.golang.org/protobuf/types/known/wrapperspb" 28 "github.com/google/tink/go/insecurecleartextkeyset" 29 "github.com/google/tink/go/jwt" 30 "github.com/google/tink/go/keyset" 31 pb "github.com/google/tink/testing/go/protos/testing_api_go_grpc" 32) 33 34// JWTService implements the JWT testing service. 35type JWTService struct { 36 pb.JwtServer 37} 38 39func (s *JWTService) CreateJwtMac(ctx context.Context, req *pb.CreationRequest) (*pb.CreationResponse, error) { 40 handle, err := toKeysetHandle(req.GetAnnotatedKeyset()) 41 if err != nil { 42 return &pb.CreationResponse{Err: err.Error()}, nil 43 } 44 _, err = jwt.NewMAC(handle) 45 if err != nil { 46 return &pb.CreationResponse{Err: err.Error()}, nil 47 } 48 return &pb.CreationResponse{}, nil 49} 50 51func (s *JWTService) CreateJwtPublicKeySign(ctx context.Context, req *pb.CreationRequest) (*pb.CreationResponse, error) { 52 handle, err := toKeysetHandle(req.GetAnnotatedKeyset()) 53 if err != nil { 54 return &pb.CreationResponse{Err: err.Error()}, nil 55 } 56 _, err = jwt.NewSigner(handle) 57 if err != nil { 58 return &pb.CreationResponse{Err: err.Error()}, nil 59 } 60 return &pb.CreationResponse{}, nil 61} 62 63func (s *JWTService) CreateJwtPublicKeyVerify(ctx context.Context, req *pb.CreationRequest) (*pb.CreationResponse, error) { 64 handle, err := toKeysetHandle(req.GetAnnotatedKeyset()) 65 if err != nil { 66 return &pb.CreationResponse{Err: err.Error()}, nil 67 } 68 _, err = jwt.NewVerifier(handle) 69 if err != nil { 70 return &pb.CreationResponse{Err: err.Error()}, nil 71 } 72 return &pb.CreationResponse{}, nil 73} 74 75func refString(s *wpb.StringValue) *string { 76 if s == nil { 77 return nil 78 } 79 v := s.GetValue() 80 return &v 81} 82 83func refTime(t *tpb.Timestamp) *time.Time { 84 if t == nil { 85 return nil 86 } 87 v := time.Unix(t.GetSeconds(), 0) 88 return &v 89} 90 91func arrayClaimToJSONString(array []interface{}) (string, error) { 92 lv, err := spb.NewList(array) 93 if err != nil { 94 return "", err 95 } 96 b, err := lv.MarshalJSON() 97 if err != nil { 98 return "", err 99 } 100 return string(b), nil 101} 102 103func jsonStringToArrayClaim(stringArray string) ([]interface{}, error) { 104 s := spb.NewListValue(&spb.ListValue{}) 105 if err := s.UnmarshalJSON([]byte(stringArray)); err != nil { 106 return nil, err 107 } 108 if s.GetListValue() == nil { 109 return nil, fmt.Errorf("invalid list") 110 } 111 return s.GetListValue().AsSlice(), nil 112} 113 114func objectClaimToJSONString(o map[string]interface{}) (string, error) { 115 s, err := spb.NewStruct(o) 116 if err != nil { 117 return "", err 118 } 119 b, err := s.MarshalJSON() 120 if err != nil { 121 return "", err 122 } 123 return string(b), nil 124} 125 126func jsonStringToObjectClaim(obj string) (map[string]interface{}, error) { 127 s := &spb.Struct{} 128 if err := s.UnmarshalJSON([]byte(obj)); err != nil { 129 return nil, err 130 } 131 return s.AsMap(), nil 132} 133 134func customClaimsFromProto(cc map[string]*pb.JwtClaimValue) (map[string]interface{}, error) { 135 r := map[string]interface{}{} 136 for key, val := range cc { 137 switch val.Kind.(type) { 138 case *pb.JwtClaimValue_NullValue: 139 r[key] = nil 140 case *pb.JwtClaimValue_StringValue: 141 r[key] = val.GetStringValue() 142 case *pb.JwtClaimValue_NumberValue: 143 r[key] = val.GetNumberValue() 144 case *pb.JwtClaimValue_BoolValue: 145 r[key] = val.GetBoolValue() 146 case *pb.JwtClaimValue_JsonArrayValue: 147 a, err := jsonStringToArrayClaim(val.GetJsonArrayValue()) 148 if err != nil { 149 return nil, err 150 } 151 r[key] = a 152 case *pb.JwtClaimValue_JsonObjectValue: 153 o, err := jsonStringToObjectClaim(val.GetJsonObjectValue()) 154 if err != nil { 155 return nil, err 156 } 157 r[key] = o 158 default: 159 return nil, fmt.Errorf("unsupported type") 160 } 161 } 162 return r, nil 163} 164 165func tokenFromProto(t *pb.JwtToken) (*jwt.RawJWT, error) { 166 if t == nil { 167 return nil, nil 168 } 169 ccs, err := customClaimsFromProto(t.GetCustomClaims()) 170 if err != nil { 171 return nil, err 172 } 173 opts := &jwt.RawJWTOptions{ 174 TypeHeader: refString(t.GetTypeHeader()), 175 Audiences: t.GetAudiences(), 176 Subject: refString(t.GetSubject()), 177 Issuer: refString(t.GetIssuer()), 178 JWTID: refString(t.GetJwtId()), 179 IssuedAt: refTime(t.GetIssuedAt()), 180 NotBefore: refTime(t.GetNotBefore()), 181 ExpiresAt: refTime(t.GetExpiration()), 182 CustomClaims: ccs, 183 } 184 if opts.ExpiresAt == nil { 185 opts.WithoutExpiration = true 186 } 187 return jwt.NewRawJWT(opts) 188} 189 190func toStringValue(present bool, getValue func() (string, error), val **wpb.StringValue) error { 191 if !present { 192 return nil 193 } 194 v, err := getValue() 195 if err != nil { 196 return err 197 } 198 *val = &wpb.StringValue{Value: v} 199 return nil 200} 201 202func toTimeValue(present bool, getValue func() (time.Time, error), val **tpb.Timestamp) error { 203 if !present { 204 return nil 205 } 206 v, err := getValue() 207 if err != nil { 208 return err 209 } 210 *val = &tpb.Timestamp{Seconds: v.Unix()} 211 return nil 212} 213 214func tokenToProto(v *jwt.VerifiedJWT) (*pb.JwtToken, error) { 215 t := &pb.JwtToken{ 216 CustomClaims: map[string]*pb.JwtClaimValue{}, 217 } 218 if err := toStringValue(v.HasTypeHeader(), v.TypeHeader, &t.TypeHeader); err != nil { 219 return nil, err 220 } 221 if err := toStringValue(v.HasIssuer(), v.Issuer, &t.Issuer); err != nil { 222 return nil, err 223 } 224 if err := toStringValue(v.HasSubject(), v.Subject, &t.Subject); err != nil { 225 return nil, err 226 } 227 if err := toStringValue(v.HasJWTID(), v.JWTID, &t.JwtId); err != nil { 228 return nil, err 229 } 230 if err := toTimeValue(v.HasExpiration(), v.ExpiresAt, &t.Expiration); err != nil { 231 return nil, err 232 } 233 if err := toTimeValue(v.HasIssuedAt(), v.IssuedAt, &t.IssuedAt); err != nil { 234 return nil, err 235 } 236 if err := toTimeValue(v.HasNotBefore(), v.NotBefore, &t.NotBefore); err != nil { 237 return nil, err 238 } 239 if v.HasAudiences() { 240 aud, err := v.Audiences() 241 if err != nil { 242 return nil, err 243 } 244 t.Audiences = aud 245 } 246 247 for _, name := range v.CustomClaimNames() { 248 if v.HasArrayClaim(name) { 249 array, err := v.ArrayClaim(name) 250 if err != nil { 251 return nil, err 252 } 253 s, err := arrayClaimToJSONString(array) 254 if err != nil { 255 return nil, err 256 } 257 t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_JsonArrayValue{JsonArrayValue: s}} 258 continue 259 } 260 if v.HasObjectClaim(name) { 261 m, err := v.ObjectClaim(name) 262 if err != nil { 263 return nil, err 264 } 265 o, err := objectClaimToJSONString(m) 266 if err != nil { 267 return nil, err 268 } 269 t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_JsonObjectValue{JsonObjectValue: o}} 270 continue 271 } 272 if v.HasNullClaim(name) { 273 t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_NullValue{}} 274 continue 275 } 276 if v.HasStringClaim(name) { 277 s, err := v.StringClaim(name) 278 if err != nil { 279 return nil, err 280 } 281 t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_StringValue{StringValue: s}} 282 continue 283 } 284 if v.HasBooleanClaim(name) { 285 b, err := v.BooleanClaim(name) 286 if err != nil { 287 return nil, err 288 } 289 t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_BoolValue{BoolValue: b}} 290 continue 291 } 292 if v.HasNumberClaim(name) { 293 n, err := v.NumberClaim(name) 294 if err != nil { 295 return nil, err 296 } 297 t.CustomClaims[name] = &pb.JwtClaimValue{Kind: &pb.JwtClaimValue_NumberValue{NumberValue: n}} 298 continue 299 } 300 return nil, fmt.Errorf("claim %q of unsupported type", name) 301 } 302 303 return t, nil 304} 305 306func validatorFromProto(v *pb.JwtValidator) (*jwt.Validator, error) { 307 fixedNow := time.Now() 308 if v.GetNow() != nil { 309 fixedNow = *refTime(v.GetNow()) 310 } 311 opts := &jwt.ValidatorOpts{ 312 ExpectedTypeHeader: refString(v.GetExpectedTypeHeader()), 313 ExpectedAudience: refString(v.GetExpectedAudience()), 314 ExpectedIssuer: refString(v.GetExpectedIssuer()), 315 ExpectIssuedInThePast: v.GetExpectIssuedInThePast(), 316 AllowMissingExpiration: v.GetAllowMissingExpiration(), 317 IgnoreTypeHeader: v.GetIgnoreTypeHeader(), 318 IgnoreAudiences: v.GetIgnoreAudience(), 319 IgnoreIssuer: v.GetIgnoreIssuer(), 320 FixedNow: fixedNow, 321 ClockSkew: time.Duration(v.GetClockSkew().GetSeconds()) * time.Second, 322 } 323 return jwt.NewValidator(opts) 324} 325 326func jwtSignResponseError(err error) *pb.JwtSignResponse { 327 return &pb.JwtSignResponse{ 328 Result: &pb.JwtSignResponse_Err{err.Error()}} 329} 330 331func jwtVerifyResponseError(err error) *pb.JwtVerifyResponse { 332 return &pb.JwtVerifyResponse{ 333 Result: &pb.JwtVerifyResponse_Err{err.Error()}} 334} 335 336func jwtToJWKSetResponseError(err error) *pb.JwtToJwkSetResponse { 337 return &pb.JwtToJwkSetResponse{ 338 Result: &pb.JwtToJwkSetResponse_Err{err.Error()}} 339} 340 341func jwtFromJwkSetResponseError(err error) *pb.JwtFromJwkSetResponse { 342 return &pb.JwtFromJwkSetResponse{ 343 Result: &pb.JwtFromJwkSetResponse_Err{err.Error()}} 344} 345 346func (s *JWTService) ComputeMacAndEncode(ctx context.Context, req *pb.JwtSignRequest) (*pb.JwtSignResponse, error) { 347 handle, err := toKeysetHandle(req.GetAnnotatedKeyset()) 348 if err != nil { 349 return jwtSignResponseError(err), nil 350 } 351 primitive, err := jwt.NewMAC(handle) 352 if err != nil { 353 return jwtSignResponseError(err), nil 354 } 355 rawJWT, err := tokenFromProto(req.GetRawJwt()) 356 if err != nil { 357 return jwtSignResponseError(err), nil 358 } 359 compact, err := primitive.ComputeMACAndEncode(rawJWT) 360 if err != nil { 361 return jwtSignResponseError(err), nil 362 } 363 return &pb.JwtSignResponse{ 364 Result: &pb.JwtSignResponse_SignedCompactJwt{compact}, 365 }, nil 366} 367 368func (s *JWTService) VerifyMacAndDecode(ctx context.Context, req *pb.JwtVerifyRequest) (*pb.JwtVerifyResponse, error) { 369 handle, err := toKeysetHandle(req.GetAnnotatedKeyset()) 370 if err != nil { 371 return jwtVerifyResponseError(err), nil 372 } 373 primitive, err := jwt.NewMAC(handle) 374 if err != nil { 375 return jwtVerifyResponseError(err), nil 376 } 377 validator, err := validatorFromProto(req.GetValidator()) 378 if err != nil { 379 return jwtVerifyResponseError(err), nil 380 } 381 verified, err := primitive.VerifyMACAndDecode(req.GetSignedCompactJwt(), validator) 382 if err != nil { 383 return jwtVerifyResponseError(err), nil 384 } 385 verifiedJWT, err := tokenToProto(verified) 386 if err != nil { 387 return jwtVerifyResponseError(err), nil 388 } 389 return &pb.JwtVerifyResponse{ 390 Result: &pb.JwtVerifyResponse_VerifiedJwt{verifiedJWT}, 391 }, nil 392} 393 394func (s *JWTService) PublicKeySignAndEncode(ctx context.Context, req *pb.JwtSignRequest) (*pb.JwtSignResponse, error) { 395 handle, err := toKeysetHandle(req.GetAnnotatedKeyset()) 396 if err != nil { 397 return jwtSignResponseError(err), nil 398 } 399 signer, err := jwt.NewSigner(handle) 400 if err != nil { 401 return jwtSignResponseError(err), nil 402 } 403 rawJWT, err := tokenFromProto(req.GetRawJwt()) 404 if err != nil { 405 return jwtSignResponseError(err), nil 406 } 407 compact, err := signer.SignAndEncode(rawJWT) 408 if err != nil { 409 return jwtSignResponseError(err), nil 410 } 411 return &pb.JwtSignResponse{ 412 Result: &pb.JwtSignResponse_SignedCompactJwt{compact}, 413 }, nil 414} 415 416func (s *JWTService) PublicKeyVerifyAndDecode(ctx context.Context, req *pb.JwtVerifyRequest) (*pb.JwtVerifyResponse, error) { 417 418 handle, err := toKeysetHandle(req.GetAnnotatedKeyset()) 419 if err != nil { 420 return jwtVerifyResponseError(err), nil 421 } 422 verifier, err := jwt.NewVerifier(handle) 423 if err != nil { 424 return jwtVerifyResponseError(err), nil 425 } 426 validator, err := validatorFromProto(req.GetValidator()) 427 if err != nil { 428 return jwtVerifyResponseError(err), nil 429 } 430 verified, err := verifier.VerifyAndDecode(req.GetSignedCompactJwt(), validator) 431 if err != nil { 432 return jwtVerifyResponseError(err), nil 433 } 434 verifiedJWT, err := tokenToProto(verified) 435 if err != nil { 436 return jwtVerifyResponseError(err), nil 437 } 438 return &pb.JwtVerifyResponse{ 439 Result: &pb.JwtVerifyResponse_VerifiedJwt{verifiedJWT}, 440 }, nil 441} 442 443func (s *JWTService) ToJwkSet(ctx context.Context, req *pb.JwtToJwkSetRequest) (*pb.JwtToJwkSetResponse, error) { 444 ks, err := keyset.NewBinaryReader(bytes.NewReader(req.GetKeyset())).Read() 445 if err != nil { 446 return jwtToJWKSetResponseError(err), nil 447 } 448 handle, err := keyset.NewHandleWithNoSecrets(ks) 449 if err != nil { 450 return jwtToJWKSetResponseError(err), nil 451 } 452 jwkSet, err := jwt.JWKSetFromPublicKeysetHandle(handle) 453 if err != nil { 454 return jwtToJWKSetResponseError(err), nil 455 } 456 return &pb.JwtToJwkSetResponse{ 457 Result: &pb.JwtToJwkSetResponse_JwkSet{string(jwkSet)}, 458 }, nil 459} 460 461func (s *JWTService) FromJwkSet(ctx context.Context, req *pb.JwtFromJwkSetRequest) (*pb.JwtFromJwkSetResponse, error) { 462 handle, err := jwt.JWKSetToPublicKeysetHandle([]byte(req.GetJwkSet())) 463 if err != nil { 464 return jwtFromJwkSetResponseError(err), nil 465 } 466 b := &bytes.Buffer{} 467 if err := insecurecleartextkeyset.Write(handle, keyset.NewBinaryWriter(b)); err != nil { 468 return jwtFromJwkSetResponseError(err), nil 469 } 470 return &pb.JwtFromJwkSetResponse{ 471 Result: &pb.JwtFromJwkSetResponse_Keyset{b.Bytes()}, 472 }, nil 473} 474