1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ////////////////////////////////////////////////////////////////////////////////
16 
17 package com.google.crypto.tink.jwt;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static org.junit.Assert.assertThrows;
21 
22 import com.google.crypto.tink.InsecureSecretKeyAccess;
23 import com.google.crypto.tink.KeyTemplate;
24 import com.google.crypto.tink.KeyTemplates;
25 import com.google.crypto.tink.KeysetHandle;
26 import com.google.crypto.tink.KeysetManager;
27 import com.google.crypto.tink.Parameters;
28 import com.google.crypto.tink.TinkProtoKeysetFormat;
29 import com.google.crypto.tink.proto.Keyset;
30 import com.google.crypto.tink.proto.OutputPrefixType;
31 import com.google.crypto.tink.testing.TestUtil;
32 import com.google.protobuf.ExtensionRegistryLite;
33 import java.security.GeneralSecurityException;
34 import java.time.Clock;
35 import java.time.Instant;
36 import java.time.temporal.ChronoUnit;
37 import org.junit.Before;
38 import org.junit.Test;
39 import org.junit.experimental.theories.DataPoints;
40 import org.junit.experimental.theories.FromDataPoints;
41 import org.junit.experimental.theories.Theories;
42 import org.junit.experimental.theories.Theory;
43 import org.junit.runner.RunWith;
44 
45 /** Tests for JwtSignKeyverifyWrapper. */
46 @RunWith(Theories.class)
47 public class JwtPublicKeySignVerifyWrappersTest {
48 
49   @DataPoints("templateNames")
50   public static final String[] TEMPLATE_NAMES =
51       new String[] {
52         "JWT_ES256",
53         "JWT_ES384",
54         "JWT_ES512",
55         "JWT_ES256_RAW",
56         "JWT_RS256_2048_F4",
57         "JWT_RS256_3072_F4",
58         "JWT_RS384_3072_F4",
59         "JWT_RS512_4096_F4",
60         "JWT_RS256_2048_F4_RAW",
61         "JWT_PS256_2048_F4",
62         "JWT_PS256_3072_F4",
63         "JWT_PS384_3072_F4",
64         "JWT_PS512_4096_F4",
65         "JWT_PS256_2048_F4_RAW",
66       };
67 
68   @Before
setUp()69   public void setUp() throws GeneralSecurityException {
70     JwtSignatureConfig.register();
71   }
72 
73   @Test
test_noPrimary_getSignPrimitive_fails()74   public void test_noPrimary_getSignPrimitive_fails() throws Exception {
75     // The old KeysetManager API allows keysets without primary key.
76     // The KeysetHandle.Builder does not allow this and can't be used in this test.
77     KeyTemplate template = KeyTemplates.get("JWT_ES256");
78     KeysetManager manager = KeysetManager.withEmptyKeyset().add(template);
79     KeysetHandle handle = manager.getKeysetHandle();
80     assertThrows(
81         GeneralSecurityException.class, () -> handle.getPrimitive(JwtPublicKeySign.class));
82   }
83 
84   @Test
test_noPrimary_getVerifyPrimitive_success()85   public void test_noPrimary_getVerifyPrimitive_success() throws Exception {
86     KeysetHandle privateKeysetHandle =
87         KeysetHandle.newBuilder()
88             .addEntry(
89                 KeysetHandle.generateEntryFromParametersName("JWT_ES256")
90                     .withRandomId()
91                     .makePrimary())
92             .build();
93     KeysetHandle publicHandle = privateKeysetHandle.getPublicKeysetHandle();
94     Object unused = publicHandle.getPrimitive(JwtPublicKeyVerify.class);
95   }
96 
97   @Test
test_wrapLegacy_throws()98   public void test_wrapLegacy_throws() throws Exception {
99     KeysetHandle handle = KeysetHandle.generateNew(KeyTemplates.get("JWT_ES256_RAW"));
100     Keyset keyset =
101         Keyset.parseFrom(
102             TinkProtoKeysetFormat.serializeKeyset(handle, InsecureSecretKeyAccess.get()),
103             ExtensionRegistryLite.getEmptyRegistry());
104     Keyset.Builder legacyKeysetBuilder = keyset.toBuilder();
105     legacyKeysetBuilder.setKey(
106         0, legacyKeysetBuilder.getKey(0).toBuilder().setOutputPrefixType(OutputPrefixType.LEGACY));
107     KeysetHandle legacyHandle =
108         TinkProtoKeysetFormat.parseKeyset(
109             legacyKeysetBuilder.build().toByteArray(), InsecureSecretKeyAccess.get());
110     assertThrows(
111         GeneralSecurityException.class, () -> legacyHandle.getPrimitive(JwtPublicKeySign.class));
112 
113     KeysetHandle publicHandle = legacyHandle.getPublicKeysetHandle();
114     assertThrows(
115         GeneralSecurityException.class, () -> publicHandle.getPrimitive(JwtPublicKeyVerify.class));
116   }
117 
118   @Test
test_wrapSingleTinkKey_works()119   public void test_wrapSingleTinkKey_works() throws Exception {
120     KeyTemplate tinkTemplate = KeyTemplates.get("JWT_ES256");
121 
122     KeysetHandle handle = KeysetHandle.generateNew(tinkTemplate);
123 
124     JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
125     JwtPublicKeyVerify verifier =
126         handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
127     RawJwt rawToken = RawJwt.newBuilder().setJwtId("blah").withoutExpiration().build();
128     String signedCompact = signer.signAndEncode(rawToken);
129     JwtValidator validator = JwtValidator.newBuilder().allowMissingExpiration().build();
130     VerifiedJwt verifiedToken = verifier.verifyAndDecode(signedCompact, validator);
131     assertThat(verifiedToken.getJwtId()).isEqualTo("blah");
132   }
133 
134   @Test
test_wrapSingleRawKey_works()135   public void test_wrapSingleRawKey_works() throws Exception {
136     KeyTemplate template = KeyTemplates.get("JWT_ES256_RAW");
137     KeysetHandle handle = KeysetHandle.generateNew(template);
138 
139     JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
140     JwtPublicKeyVerify verifier =
141         handle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
142     RawJwt rawToken = RawJwt.newBuilder().setJwtId("blah").withoutExpiration().build();
143     String signedCompact = signer.signAndEncode(rawToken);
144     JwtValidator validator = JwtValidator.newBuilder().allowMissingExpiration().build();
145     VerifiedJwt verifiedToken = verifier.verifyAndDecode(signedCompact, validator);
146     assertThat(verifiedToken.getJwtId()).isEqualTo("blah");
147   }
148 
149   @Test
test_wrapMultipleRawKeys()150   public void test_wrapMultipleRawKeys() throws Exception {
151     KeysetHandle oldHandle =
152         KeysetHandle.newBuilder()
153             .addEntry(
154                 KeysetHandle.generateEntryFromParametersName("JWT_ES256_RAW")
155                     .withRandomId()
156                     .makePrimary())
157             .build();
158     KeysetHandle newHandle =
159         KeysetHandle.newBuilder(oldHandle)
160             .addEntry(
161                 KeysetHandle.generateEntryFromParametersName("JWT_ES256_RAW")
162                     .withRandomId()
163                     .makePrimary())
164             .build();
165 
166     JwtPublicKeySign oldSigner = oldHandle.getPrimitive(JwtPublicKeySign.class);
167     JwtPublicKeySign newSigner = newHandle.getPrimitive(JwtPublicKeySign.class);
168 
169     JwtPublicKeyVerify oldVerifier =
170         oldHandle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
171     JwtPublicKeyVerify newVerifier =
172         newHandle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
173 
174     RawJwt rawToken = RawJwt.newBuilder().setJwtId("jwtId").withoutExpiration().build();
175     String oldSignedCompact = oldSigner.signAndEncode(rawToken);
176     String newSignedCompact = newSigner.signAndEncode(rawToken);
177 
178     JwtValidator validator = JwtValidator.newBuilder().allowMissingExpiration().build();
179     assertThat(oldVerifier.verifyAndDecode(oldSignedCompact, validator).getJwtId())
180         .isEqualTo("jwtId");
181     assertThat(newVerifier.verifyAndDecode(oldSignedCompact, validator).getJwtId())
182         .isEqualTo("jwtId");
183     assertThat(newVerifier.verifyAndDecode(newSignedCompact, validator).getJwtId())
184         .isEqualTo("jwtId");
185     assertThrows(
186         GeneralSecurityException.class,
187         () -> oldVerifier.verifyAndDecode(newSignedCompact, validator));
188   }
189 
190   @Test
test_wrapMultipleTinkKeys()191   public void test_wrapMultipleTinkKeys() throws Exception {
192     KeysetHandle oldHandle =
193         KeysetHandle.newBuilder()
194             .addEntry(
195                 KeysetHandle.generateEntryFromParametersName("JWT_ES256")
196                     .withRandomId()
197                     .makePrimary())
198             .build();
199     KeysetHandle newHandle =
200         KeysetHandle.newBuilder(oldHandle)
201             .addEntry(
202                 KeysetHandle.generateEntryFromParametersName("JWT_ES256")
203                     .withRandomId()
204                     .makePrimary())
205             .build();
206 
207     JwtPublicKeySign oldSigner = oldHandle.getPrimitive(JwtPublicKeySign.class);
208     JwtPublicKeySign newSigner = newHandle.getPrimitive(JwtPublicKeySign.class);
209 
210     JwtPublicKeyVerify oldVerifier =
211         oldHandle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
212     JwtPublicKeyVerify newVerifier =
213         newHandle.getPublicKeysetHandle().getPrimitive(JwtPublicKeyVerify.class);
214 
215     RawJwt rawToken = RawJwt.newBuilder().setJwtId("jwtId").withoutExpiration().build();
216     String oldSignedCompact = oldSigner.signAndEncode(rawToken);
217     String newSignedCompact = newSigner.signAndEncode(rawToken);
218 
219     JwtValidator validator = JwtValidator.newBuilder().allowMissingExpiration().build();
220     assertThat(oldVerifier.verifyAndDecode(oldSignedCompact, validator).getJwtId())
221         .isEqualTo("jwtId");
222     assertThat(newVerifier.verifyAndDecode(oldSignedCompact, validator).getJwtId())
223         .isEqualTo("jwtId");
224     assertThat(newVerifier.verifyAndDecode(newSignedCompact, validator).getJwtId())
225         .isEqualTo("jwtId");
226     assertThrows(
227         GeneralSecurityException.class,
228         () -> oldVerifier.verifyAndDecode(newSignedCompact, validator));
229   }
230 
231   // Note: we use Theory as a parametrized test -- different from what the Theory framework intends.
232   @Theory
wrongKey_throwsInvalidSignatureException( @romDataPoints"templateNames") String templateName)233   public void wrongKey_throwsInvalidSignatureException(
234       @FromDataPoints("templateNames") String templateName) throws Exception {
235     if (TestUtil.isTsan()) {
236       // KeysetHandle.generateNew is too slow in Tsan.
237       // We do not use assume because Theories expects to find something which is not skipped.
238       return;
239     }
240     KeyTemplate template = KeyTemplates.get(templateName);
241     KeysetHandle keysetHandle = KeysetHandle.generateNew(template);
242     JwtPublicKeySign jwtSign = keysetHandle.getPrimitive(JwtPublicKeySign.class);
243     RawJwt rawJwt = RawJwt.newBuilder().withoutExpiration().build();
244     String compact = jwtSign.signAndEncode(rawJwt);
245     JwtValidator validator = JwtValidator.newBuilder().allowMissingExpiration().build();
246 
247     KeysetHandle wrongKeysetHandle = KeysetHandle.generateNew(template);
248     KeysetHandle wrongPublicKeysetHandle = wrongKeysetHandle.getPublicKeysetHandle();
249 
250     JwtPublicKeyVerify wrongJwtVerify =
251         wrongPublicKeysetHandle.getPrimitive(JwtPublicKeyVerify.class);
252     assertThrows(
253         GeneralSecurityException.class, () -> wrongJwtVerify.verifyAndDecode(compact, validator));
254   }
255 
256   @Test
wrongIssuer_throwsInvalidException()257   public void wrongIssuer_throwsInvalidException() throws Exception {
258     KeyTemplate template = KeyTemplates.get("JWT_ES256");
259     KeysetHandle keysetHandle = KeysetHandle.generateNew(template);
260     JwtPublicKeySign jwtSigner = keysetHandle.getPrimitive(JwtPublicKeySign.class);
261     KeysetHandle publicHandle = keysetHandle.getPublicKeysetHandle();
262     JwtPublicKeyVerify jwtVerifier = publicHandle.getPrimitive(JwtPublicKeyVerify.class);
263     RawJwt rawJwt = RawJwt.newBuilder().setIssuer("Justus").withoutExpiration().build();
264     String compact = jwtSigner.signAndEncode(rawJwt);
265     JwtValidator validator =
266         JwtValidator.newBuilder().expectIssuer("Peter").allowMissingExpiration().build();
267     assertThrows(JwtInvalidException.class, () -> jwtVerifier.verifyAndDecode(compact, validator));
268   }
269 
270   @Test
expiredCompact_throwsInvalidException()271   public void expiredCompact_throwsInvalidException() throws Exception {
272     KeyTemplate template = KeyTemplates.get("JWT_ES256");
273     KeysetHandle keysetHandle = KeysetHandle.generateNew(template);
274     JwtPublicKeySign jwtSigner = keysetHandle.getPrimitive(JwtPublicKeySign.class);
275     KeysetHandle publicHandle = keysetHandle.getPublicKeysetHandle();
276     JwtPublicKeyVerify jwtVerifier = publicHandle.getPrimitive(JwtPublicKeyVerify.class);
277 
278     Instant now = Clock.systemUTC().instant().truncatedTo(ChronoUnit.SECONDS);
279     RawJwt rawJwt =
280         RawJwt.newBuilder()
281             .setExpiration(now.minusSeconds(100)) // exipired 100 seconds ago
282             .setIssuedAt(now.minusSeconds(200))
283             .build();
284     String compact = jwtSigner.signAndEncode(rawJwt);
285     JwtValidator validator = JwtValidator.newBuilder().build();
286     assertThrows(JwtInvalidException.class, () -> jwtVerifier.verifyAndDecode(compact, validator));
287   }
288 
289   @Test
notYetValidCompact_throwsInvalidException()290   public void notYetValidCompact_throwsInvalidException() throws Exception {
291     KeyTemplate template = KeyTemplates.get("JWT_ES256");
292     KeysetHandle keysetHandle = KeysetHandle.generateNew(template);
293     JwtPublicKeySign jwtSigner = keysetHandle.getPrimitive(JwtPublicKeySign.class);
294     KeysetHandle publicHandle = keysetHandle.getPublicKeysetHandle();
295     JwtPublicKeyVerify jwtVerifier = publicHandle.getPrimitive(JwtPublicKeyVerify.class);
296 
297     Instant now = Clock.systemUTC().instant().truncatedTo(ChronoUnit.SECONDS);
298     RawJwt rawJwt =
299         RawJwt.newBuilder()
300             .setNotBefore(now.plusSeconds(3600)) // is valid in 1 hour, but not before
301             .setIssuedAt(now)
302             .withoutExpiration()
303             .build();
304     String compact = jwtSigner.signAndEncode(rawJwt);
305     JwtValidator validator = JwtValidator.newBuilder().allowMissingExpiration().build();
306     assertThrows(JwtInvalidException.class, () -> jwtVerifier.verifyAndDecode(compact, validator));
307   }
308 
309   /* TODO: b/252792776. All keysets without primary should be rejected in every case. */
310   @Test
test_verifyWithoutPrimary_works()311   public void test_verifyWithoutPrimary_works() throws Exception {
312     Parameters parameters = KeyTemplates.get("JWT_ES256").toParameters();
313     KeysetHandle handle =
314         KeysetHandle.newBuilder()
315             .addEntry(
316                 KeysetHandle.generateEntryFromParameters(parameters).withRandomId().makePrimary())
317             .addEntry(KeysetHandle.generateEntryFromParameters(parameters).withRandomId())
318             .build();
319     KeysetHandle publicHandle = handle.getPublicKeysetHandle();
320     Keyset publicKeyset =
321         Keyset.parseFrom(TinkProtoKeysetFormat.serializeKeysetWithoutSecret(publicHandle));
322     Keyset publicKeysetWithoutPrimary = publicKeyset.toBuilder().setPrimaryKeyId(0).build();
323     // TODO(b/252792776): Optimally, this would throw.
324     KeysetHandle publicHandleWithoutPrimary =
325         TinkProtoKeysetFormat.parseKeysetWithoutSecret(publicKeysetWithoutPrimary.toByteArray());
326 
327     JwtPublicKeySign signer = handle.getPrimitive(JwtPublicKeySign.class);
328     // TODO(b/252792776): At least this should throw.
329     JwtPublicKeyVerify verifier = publicHandleWithoutPrimary.getPrimitive(JwtPublicKeyVerify.class);
330     RawJwt rawToken = RawJwt.newBuilder().setJwtId("blah").withoutExpiration().build();
331     String signedCompact = signer.signAndEncode(rawToken);
332     JwtValidator validator = JwtValidator.newBuilder().allowMissingExpiration().build();
333     VerifiedJwt verifiedToken = verifier.verifyAndDecode(signedCompact, validator);
334     assertThat(verifiedToken.getJwtId()).isEqualTo("blah");
335   }
336 }
337