1// Copyright 2022 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package jwt 18 19import ( 20 "encoding/base64" 21 "testing" 22 "time" 23 24 "github.com/google/go-cmp/cmp" 25 "google.golang.org/protobuf/proto" 26 "github.com/google/tink/go/core/registry" 27 "github.com/google/tink/go/subtle/random" 28 jwtmacpb "github.com/google/tink/go/proto/jwt_hmac_go_proto" 29 tinkpb "github.com/google/tink/go/proto/tink_go_proto" 30) 31 32type jwtKeyManagerTestCase struct { 33 tag string 34 keyFormat *jwtmacpb.JwtHmacKeyFormat 35 key *jwtmacpb.JwtHmacKey 36} 37 38const ( 39 typeURL = "type.googleapis.com/google.crypto.tink.JwtHmacKey" 40) 41 42func generateKeyFormat(keySize uint32, algorithm jwtmacpb.JwtHmacAlgorithm) *jwtmacpb.JwtHmacKeyFormat { 43 return &jwtmacpb.JwtHmacKeyFormat{ 44 KeySize: keySize, 45 Algorithm: algorithm, 46 } 47} 48 49func TestDoesSupport(t *testing.T) { 50 km, err := registry.GetKeyManager(typeURL) 51 if err != nil { 52 t.Errorf("registry.GetKeyManager(%q) error = %v, want nil", typeURL, err) 53 } 54 if !km.DoesSupport(typeURL) { 55 t.Errorf("km.DoesSupport(%q) = false, want true", typeURL) 56 } 57} 58 59func TestTypeURL(t *testing.T) { 60 km, err := registry.GetKeyManager(typeURL) 61 if err != nil { 62 t.Errorf("registry.GetKeyManager(%q) error = %v, want nil", typeURL, err) 63 } 64 if km.TypeURL() != typeURL { 65 t.Errorf("km.TypeURL() = %q, want %q", km.TypeURL(), typeURL) 66 } 67} 68 69var invalidKeyFormatTestCases = []jwtKeyManagerTestCase{ 70 { 71 tag: "invalid hash algorithm", 72 keyFormat: generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS_UNKNOWN), 73 }, 74 { 75 tag: "invalid HS256 key size", 76 keyFormat: generateKeyFormat(31, jwtmacpb.JwtHmacAlgorithm_HS256), 77 }, 78 { 79 tag: "invalid HS384 key size", 80 keyFormat: generateKeyFormat(47, jwtmacpb.JwtHmacAlgorithm_HS384), 81 }, 82 { 83 tag: "invalid HS512 key size", 84 keyFormat: generateKeyFormat(63, jwtmacpb.JwtHmacAlgorithm_HS512), 85 }, 86 { 87 tag: "empty key format", 88 keyFormat: &jwtmacpb.JwtHmacKeyFormat{}, 89 }, 90 { 91 tag: "nil key format", 92 keyFormat: nil, 93 }, 94} 95 96func TestNewKeyInvalidFormatFails(t *testing.T) { 97 for _, tc := range invalidKeyFormatTestCases { 98 t.Run(tc.tag, func(t *testing.T) { 99 km, err := registry.GetKeyManager(typeURL) 100 if err != nil { 101 t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) 102 } 103 serializedKeyFormat, err := proto.Marshal(tc.keyFormat) 104 if err != nil { 105 t.Errorf("serializing key format: %v", err) 106 } 107 if _, err := km.NewKey(serializedKeyFormat); err == nil { 108 t.Errorf("km.NewKey() err = nil, want error") 109 } 110 }) 111 } 112} 113 114func TestNewDataInvalidFormatFails(t *testing.T) { 115 for _, tc := range invalidKeyFormatTestCases { 116 t.Run(tc.tag, func(t *testing.T) { 117 km, err := registry.GetKeyManager(typeURL) 118 if err != nil { 119 t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) 120 } 121 serializedKeyFormat, err := proto.Marshal(tc.keyFormat) 122 if err != nil { 123 t.Errorf("serializing key format: %v", err) 124 } 125 if _, err := km.NewKeyData(serializedKeyFormat); err == nil { 126 t.Errorf("km.NewKey() err = nil, want error") 127 } 128 }) 129 } 130} 131 132var validKeyFormatTestCases = []jwtKeyManagerTestCase{ 133 { 134 tag: "SHA256 hash algorithm", 135 keyFormat: generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS256), 136 }, 137 { 138 tag: "SHA384 hash algorithm", 139 keyFormat: generateKeyFormat(48, jwtmacpb.JwtHmacAlgorithm_HS384), 140 }, 141 { 142 tag: "SHA512 hash algorithm", 143 keyFormat: generateKeyFormat(64, jwtmacpb.JwtHmacAlgorithm_HS512), 144 }, 145} 146 147func TestNewKey(t *testing.T) { 148 for _, tc := range validKeyFormatTestCases { 149 t.Run(tc.tag, func(t *testing.T) { 150 km, err := registry.GetKeyManager(typeURL) 151 if err != nil { 152 t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) 153 } 154 serializedKeyFormat, err := proto.Marshal(tc.keyFormat) 155 if err != nil { 156 t.Errorf("serializing key format: %v", err) 157 } 158 k, err := km.NewKey(serializedKeyFormat) 159 if err != nil { 160 t.Errorf("km.NewKey() err = %v, want nil", err) 161 } 162 key, ok := k.(*jwtmacpb.JwtHmacKey) 163 if !ok { 164 t.Errorf("key isn't of type JwtHmacKey") 165 } 166 if key.Algorithm != tc.keyFormat.Algorithm { 167 t.Errorf("k.Algorithm = %v, want %v", key.Algorithm, tc.keyFormat.Algorithm) 168 } 169 if len(key.KeyValue) != int(tc.keyFormat.KeySize) { 170 t.Errorf("len(key.KeyValue) = %d, want %d", len(key.KeyValue), tc.keyFormat.KeySize) 171 } 172 }) 173 } 174} 175 176func TestNewKeyData(t *testing.T) { 177 for _, tc := range validKeyFormatTestCases { 178 t.Run(tc.tag, func(t *testing.T) { 179 km, err := registry.GetKeyManager(typeURL) 180 if err != nil { 181 t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) 182 } 183 serializedKeyFormat, err := proto.Marshal(tc.keyFormat) 184 if err != nil { 185 t.Errorf("serializing key format: %v", err) 186 } 187 k, err := km.NewKeyData(serializedKeyFormat) 188 if err != nil { 189 t.Errorf("km.NewKeyData() err = %v, want nil", err) 190 } 191 if k.GetTypeUrl() != typeURL { 192 t.Errorf("k.GetTypeUrl() = %q, want %q", k.GetTypeUrl(), typeURL) 193 } 194 if k.GetKeyMaterialType() != tinkpb.KeyData_SYMMETRIC { 195 t.Errorf("k.GetKeyMaterialType() = %q, want %q", k.GetKeyMaterialType(), tinkpb.KeyData_SYMMETRIC) 196 } 197 }) 198 } 199} 200 201func generateKey(keySize, version uint32, algorithm jwtmacpb.JwtHmacAlgorithm, kid *jwtmacpb.JwtHmacKey_CustomKid) *jwtmacpb.JwtHmacKey { 202 return &jwtmacpb.JwtHmacKey{ 203 KeyValue: random.GetRandomBytes(keySize), 204 Algorithm: algorithm, 205 CustomKid: kid, 206 Version: version, 207 } 208} 209 210func TestGetPrimitiveWithValidKeys(t *testing.T) { 211 rawJWT, err := NewRawJWT(&RawJWTOptions{WithoutExpiration: true, Audiences: []string{"tink-aud"}}) 212 if err != nil { 213 t.Fatalf("NewRawJWT() err = %v, want nil", err) 214 } 215 validator, err := NewValidator(&ValidatorOpts{AllowMissingExpiration: true, ExpectedAudience: refString("tink-aud")}) 216 if err != nil { 217 t.Fatalf("NewValidator() err = %v, want nil", err) 218 } 219 for _, tc := range []jwtKeyManagerTestCase{ 220 { 221 tag: "SHA256 hash algorithm", 222 key: generateKey(32, 0, jwtmacpb.JwtHmacAlgorithm_HS256, nil), 223 }, 224 { 225 tag: "SHA384 hash algorithm", 226 key: generateKey(48, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil), 227 }, 228 { 229 tag: "SHA512 hash algorithm", 230 key: generateKey(64, 0, jwtmacpb.JwtHmacAlgorithm_HS512, nil), 231 }, 232 { 233 tag: "with custom kid", 234 key: generateKey(64, 0, jwtmacpb.JwtHmacAlgorithm_HS512, &jwtmacpb.JwtHmacKey_CustomKid{Value: "1235"}), 235 }, 236 } { 237 t.Run(tc.tag, func(t *testing.T) { 238 km, err := registry.GetKeyManager(typeURL) 239 if err != nil { 240 t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) 241 } 242 serializedKey, err := proto.Marshal(tc.key) 243 if err != nil { 244 t.Errorf("serializing key format: %v", err) 245 } 246 p, err := km.Primitive(serializedKey) 247 if err != nil { 248 t.Errorf("km.Primitive() err = %v, want nil", err) 249 } 250 primitive, ok := p.(*macWithKID) 251 if !ok { 252 t.Errorf("primitive isn't of type: macWithKID") 253 } 254 compact, err := primitive.ComputeMACAndEncodeWithKID(rawJWT, nil) 255 if err != nil { 256 t.Errorf("ComputeMACAndEncodeWithKID() err = %v, want nil", err) 257 } 258 verifiedJWT, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, nil) 259 if err != nil { 260 t.Errorf("VerifyMACAndDecodeWithKID() err = %v, want nil", err) 261 } 262 audiences, err := verifiedJWT.Audiences() 263 if err != nil { 264 t.Errorf("verifiedJWT.Audiences() err = %v, want nil", err) 265 } 266 if !cmp.Equal(audiences, []string{"tink-aud"}) { 267 t.Errorf("verifiedJWT.Audiences() = %q, want ['tink-aud']", audiences) 268 } 269 270 }) 271 } 272} 273 274func TestGetPrimitiveWithInvalidKeys(t *testing.T) { 275 for _, tc := range []jwtKeyManagerTestCase{ 276 { 277 tag: "HS256", 278 key: generateKey(31, 0, jwtmacpb.JwtHmacAlgorithm_HS256, nil), 279 }, 280 { 281 tag: "HS384", 282 key: generateKey(47, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil), 283 }, 284 { 285 tag: "HS512", 286 key: generateKey(63, 0, jwtmacpb.JwtHmacAlgorithm_HS512, nil), 287 }, 288 } { 289 t.Run(tc.tag, func(t *testing.T) { 290 km, err := registry.GetKeyManager(typeURL) 291 if err != nil { 292 t.Fatalf("registry.GetKeyManager(%q) err=%q, want nil", typeURL, err) 293 } 294 serializedKey, err := proto.Marshal(tc.key) 295 if err != nil { 296 t.Fatalf("proto.Marshal(tc.key) err =%q, want nil", err) 297 } 298 _, err = km.Primitive(serializedKey) 299 if err == nil { 300 t.Error("km.Primitive(serializedKey) err = nil, want error") 301 } 302 }) 303 } 304} 305 306func TestSpecyfingCustomKIDAndTINKKIDFails(t *testing.T) { 307 // key and compact are examples from: https://datatracker.ietf.org/doc/html/rfc7515#appendix-A.1.1 308 compact := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" 309 rawKey, err := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString("AyM1SysPpbyDfgZld3umj1qzKObwVMkoqQ-EstJQLr_T-1qS0gZH75aKtMN3Yj0iPS4hcgUuTwjAzZr1Z9CAow") 310 if err != nil { 311 t.Fatalf("failed decoding test key: %v", err) 312 } 313 key := &jwtmacpb.JwtHmacKey{ 314 KeyValue: rawKey, 315 Algorithm: jwtmacpb.JwtHmacAlgorithm_HS256, 316 CustomKid: &jwtmacpb.JwtHmacKey_CustomKid{Value: "1235"}, 317 Version: 0, 318 } 319 km, err := registry.GetKeyManager(typeURL) 320 if err != nil { 321 t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) 322 } 323 serializedKey, err := proto.Marshal(key) 324 if err != nil { 325 t.Errorf("serializing key format: %v", err) 326 } 327 p, err := km.Primitive(serializedKey) 328 if err != nil { 329 t.Errorf("km.Primitive() err = %v, want nil", err) 330 } 331 primitive, ok := p.(*macWithKID) 332 if !ok { 333 t.Errorf("primitive isn't of type: macWithKID") 334 } 335 336 rawJWT, err := NewRawJWT(&RawJWTOptions{WithoutExpiration: true}) 337 if err != nil { 338 t.Errorf("creating new RawJWT: %v", err) 339 } 340 opts := &ValidatorOpts{ 341 ExpectedTypeHeader: refString("JWT"), 342 ExpectedIssuer: refString("joe"), 343 FixedNow: time.Unix(12345, 0), 344 } 345 validator, err := NewValidator(opts) 346 if err != nil { 347 t.Errorf("creating new JWTValidator: %v", err) 348 } 349 if _, err := primitive.ComputeMACAndEncodeWithKID(rawJWT, refString("4566")); err == nil { 350 t.Errorf("primitive.ComputeMACAndEncodeWithKID() err = nil, want error") 351 } 352 if _, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, refString("4566")); err == nil { 353 t.Errorf("primitive.VerifyMACAndDecodeWithKID(kid = 4566) err = nil, want error") 354 } 355 // Verify success without KID 356 if _, err := primitive.VerifyMACAndDecodeWithKID(compact, validator, nil); err != nil { 357 t.Errorf("primitive.VerifyMACAndDecodeWithKID(kid = nil) err = %v, want nil", err) 358 } 359} 360 361func TestGetPrimitiveWithInvalidKeyFails(t *testing.T) { 362 for _, tc := range []jwtKeyManagerTestCase{ 363 { 364 tag: "empty key", 365 key: &jwtmacpb.JwtHmacKey{}, 366 }, 367 { 368 tag: "nil key", 369 key: nil, 370 }, 371 { 372 tag: "unsupported hash algorithm", 373 key: generateKey(32, 0, jwtmacpb.JwtHmacAlgorithm_HS_UNKNOWN, nil), 374 }, 375 { 376 tag: "short key length", 377 key: generateKey(20, 0, jwtmacpb.JwtHmacAlgorithm_HS384, nil), 378 }, 379 { 380 tag: "unsupported version", 381 key: generateKey(48, 1, jwtmacpb.JwtHmacAlgorithm_HS384, nil), 382 }, 383 } { 384 t.Run(tc.tag, func(t *testing.T) { 385 km, err := registry.GetKeyManager(typeURL) 386 if err != nil { 387 t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) 388 } 389 serializedKey, err := proto.Marshal(tc.key) 390 if err != nil { 391 t.Errorf("serializing key format: %v", err) 392 } 393 if _, err := km.Primitive(serializedKey); err == nil { 394 t.Errorf("km.Primitive() err = nil, want error") 395 } 396 }) 397 } 398} 399 400func TestGeneratesDifferentKeys(t *testing.T) { 401 km, err := registry.GetKeyManager(typeURL) 402 if err != nil { 403 t.Errorf("registry.GetKeyManager(%q): %v", typeURL, err) 404 } 405 serializedKeyFormat, err := proto.Marshal(generateKeyFormat(32, jwtmacpb.JwtHmacAlgorithm_HS256)) 406 if err != nil { 407 t.Errorf("serializing key format: %v", err) 408 } 409 k1, err := km.NewKey(serializedKeyFormat) 410 if err != nil { 411 t.Errorf("km.NewKey() err = %v, want nil", err) 412 } 413 k2, err := km.NewKey(serializedKeyFormat) 414 if err != nil { 415 t.Errorf("km.NewKey() err = %v, want nil", err) 416 } 417 key1, ok := k1.(*jwtmacpb.JwtHmacKey) 418 if !ok { 419 t.Errorf("k1 isn't of type JwtHmacKey") 420 } 421 key2, ok := k2.(*jwtmacpb.JwtHmacKey) 422 if !ok { 423 t.Errorf("k2 isn't of type JwtHmacKey") 424 } 425 if cmp.Equal(key1.GetKeyValue(), key2.GetKeyValue()) { 426 t.Errorf("key material should differ") 427 } 428} 429