1 /*
2  * Copyright 2020 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.xds.internal.security.certprovider;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE;
21 import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
22 import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
23 import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE;
24 import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
25 import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
26 import static java.nio.file.StandardCopyOption.REPLACE_EXISTING;
27 import static org.mockito.ArgumentMatchers.any;
28 import static org.mockito.ArgumentMatchers.eq;
29 import static org.mockito.Mockito.doReturn;
30 import static org.mockito.Mockito.never;
31 import static org.mockito.Mockito.reset;
32 import static org.mockito.Mockito.times;
33 import static org.mockito.Mockito.verify;
34 
35 import io.grpc.Status;
36 import io.grpc.internal.TimeProvider;
37 import io.grpc.xds.internal.security.CommonTlsContextTestsUtil;
38 import io.grpc.xds.internal.security.certprovider.CertificateProvider.DistributorWatcher;
39 import java.io.File;
40 import java.io.IOException;
41 import java.nio.file.Files;
42 import java.nio.file.NoSuchFileException;
43 import java.nio.file.Paths;
44 import java.nio.file.attribute.FileTime;
45 import java.security.PrivateKey;
46 import java.security.cert.CertificateException;
47 import java.security.cert.X509Certificate;
48 import java.util.ArrayList;
49 import java.util.List;
50 import java.util.concurrent.Delayed;
51 import java.util.concurrent.ScheduledExecutorService;
52 import java.util.concurrent.ScheduledFuture;
53 import java.util.concurrent.TimeUnit;
54 import org.junit.Before;
55 import org.junit.Rule;
56 import org.junit.Test;
57 import org.junit.rules.TemporaryFolder;
58 import org.junit.runner.RunWith;
59 import org.junit.runners.JUnit4;
60 import org.mockito.ArgumentCaptor;
61 import org.mockito.ArgumentMatchers;
62 import org.mockito.Mock;
63 import org.mockito.junit.MockitoJUnit;
64 import org.mockito.junit.MockitoRule;
65 
66 /** Unit tests for {@link FileWatcherCertificateProvider}. */
67 @RunWith(JUnit4.class)
68 public class FileWatcherCertificateProviderTest {
69   /**
70    * Expire time of cert SERVER_0_PEM_FILE.
71    */
72   static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L;
73   private static final String CERT_FILE = "cert.pem";
74   private static final String KEY_FILE = "key.pem";
75   private static final String ROOT_FILE = "root.pem";
76 
77   @Mock private CertificateProvider.Watcher mockWatcher;
78   @Mock private ScheduledExecutorService timeService;
79   private final FakeTimeProvider timeProvider = new FakeTimeProvider();
80 
81   @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
82   @Rule public final MockitoRule mocks = MockitoJUnit.rule();
83 
84   private String certFile;
85   private String keyFile;
86   private String rootFile;
87 
88   private FileWatcherCertificateProvider provider;
89 
90   @Before
setUp()91   public void setUp() throws IOException {
92     DistributorWatcher watcher = new DistributorWatcher();
93     watcher.addWatcher(mockWatcher);
94 
95     certFile = new File(tempFolder.getRoot(), CERT_FILE).getAbsolutePath();
96     keyFile = new File(tempFolder.getRoot(), KEY_FILE).getAbsolutePath();
97     rootFile = new File(tempFolder.getRoot(), ROOT_FILE).getAbsolutePath();
98     provider =
99         new FileWatcherCertificateProvider(
100             watcher, true, certFile, keyFile, rootFile, 600L, timeService, timeProvider);
101   }
102 
populateTarget( String certFileSource, String keyFileSource, String rootFileSource, boolean deleteCurCert, boolean deleteCurKey, boolean deleteCurRoot)103   private void populateTarget(
104       String certFileSource,
105       String keyFileSource,
106       String rootFileSource,
107       boolean deleteCurCert,
108       boolean deleteCurKey,
109       boolean deleteCurRoot)
110       throws IOException {
111     if (deleteCurCert) {
112       Files.delete(Paths.get(certFile));
113     }
114     if (certFileSource != null) {
115       certFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(certFileSource);
116       Files.copy(Paths.get(certFileSource), Paths.get(certFile), REPLACE_EXISTING);
117       Files.setLastModifiedTime(
118           Paths.get(certFile), FileTime.fromMillis(timeProvider.currentTimeMillis()));
119     }
120     if (deleteCurKey) {
121       Files.delete(Paths.get(keyFile));
122     }
123     if (keyFileSource != null) {
124       keyFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(keyFileSource);
125       Files.copy(Paths.get(keyFileSource), Paths.get(keyFile), REPLACE_EXISTING);
126       Files.setLastModifiedTime(
127           Paths.get(keyFile), FileTime.fromMillis(timeProvider.currentTimeMillis()));
128     }
129     if (deleteCurRoot) {
130       Files.delete(Paths.get(rootFile));
131     }
132     if (rootFileSource != null) {
133       rootFileSource = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(rootFileSource);
134       Files.copy(Paths.get(rootFileSource), Paths.get(rootFile), REPLACE_EXISTING);
135       Files.setLastModifiedTime(
136           Paths.get(rootFile), FileTime.fromMillis(timeProvider.currentTimeMillis()));
137     }
138   }
139 
140   @Test
getCertificateAndCheckUpdates()141   public void getCertificateAndCheckUpdates() throws IOException, CertificateException {
142     TestScheduledFuture<?> scheduledFuture =
143         new TestScheduledFuture<>();
144     doReturn(scheduledFuture)
145         .when(timeService)
146         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
147     populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
148     provider.checkAndReloadCertificates();
149     verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE);
150     verifyTimeServiceAndScheduledFuture();
151 
152     reset(mockWatcher, timeService);
153     doReturn(scheduledFuture)
154         .when(timeService)
155         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
156     provider.checkAndReloadCertificates();
157     verifyWatcherErrorUpdates(null, null, 0, 0, (String[]) null);
158     verifyTimeServiceAndScheduledFuture();
159   }
160 
161   @Test
allUpdateSecondTime()162   public void allUpdateSecondTime() throws IOException, CertificateException, InterruptedException {
163     TestScheduledFuture<?> scheduledFuture =
164         new TestScheduledFuture<>();
165     doReturn(scheduledFuture)
166         .when(timeService)
167         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
168     populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
169     provider.checkAndReloadCertificates();
170 
171     reset(mockWatcher, timeService);
172     doReturn(scheduledFuture)
173         .when(timeService)
174         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
175     timeProvider.forwardTime(1, TimeUnit.SECONDS);
176     populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false);
177     provider.checkAndReloadCertificates();
178     verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE);
179     verifyTimeServiceAndScheduledFuture();
180   }
181 
182   @Test
closeDoesNotScheduleNext()183   public void closeDoesNotScheduleNext() throws IOException, CertificateException {
184     TestScheduledFuture<?> scheduledFuture =
185             new TestScheduledFuture<>();
186     doReturn(scheduledFuture)
187             .when(timeService)
188             .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
189     populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
190     provider.close();
191     provider.checkAndReloadCertificates();
192     verify(mockWatcher, never())
193         .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
194     verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
195     verify(timeService, never()).schedule(any(Runnable.class), any(Long.TYPE), any(TimeUnit.class));
196     verify(timeService, times(1)).shutdownNow();
197   }
198 
199 
200   @Test
rootFileUpdateOnly()201   public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException {
202     TestScheduledFuture<?> scheduledFuture =
203         new TestScheduledFuture<>();
204     doReturn(scheduledFuture)
205         .when(timeService)
206         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
207     populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
208     provider.checkAndReloadCertificates();
209 
210     reset(mockWatcher, timeService);
211     doReturn(scheduledFuture)
212         .when(timeService)
213         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
214     timeProvider.forwardTime(1, TimeUnit.SECONDS);
215     populateTarget(null, null, SERVER_1_PEM_FILE, false, false, false);
216     provider.checkAndReloadCertificates();
217     verifyWatcherUpdates(null, SERVER_1_PEM_FILE);
218     verifyTimeServiceAndScheduledFuture();
219   }
220 
221   @Test
certAndKeyFileUpdateOnly()222   public void certAndKeyFileUpdateOnly()
223       throws IOException, CertificateException, InterruptedException {
224     TestScheduledFuture<?> scheduledFuture =
225         new TestScheduledFuture<>();
226     doReturn(scheduledFuture)
227         .when(timeService)
228         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
229     populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
230     provider.checkAndReloadCertificates();
231 
232     reset(mockWatcher, timeService);
233     doReturn(scheduledFuture)
234         .when(timeService)
235         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
236     timeProvider.forwardTime(1, TimeUnit.SECONDS);
237     populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, false, false, false);
238     provider.checkAndReloadCertificates();
239     verifyWatcherUpdates(SERVER_0_PEM_FILE, null);
240     verifyTimeServiceAndScheduledFuture();
241   }
242 
243   @Test
getCertificate_initialMissingCertFile()244   public void getCertificate_initialMissingCertFile() throws IOException {
245     TestScheduledFuture<?> scheduledFuture =
246         new TestScheduledFuture<>();
247     doReturn(scheduledFuture)
248         .when(timeService)
249         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
250     populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
251     provider.checkAndReloadCertificates();
252     verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 0, 1, "cert.pem");
253   }
254 
255   @Test
getCertificate_missingCertFile()256   public void getCertificate_missingCertFile() throws IOException, InterruptedException {
257     commonErrorTest(
258         null, CLIENT_KEY_FILE, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "cert.pem");
259   }
260 
261   @Test
getCertificate_missingKeyFile()262   public void getCertificate_missingKeyFile() throws IOException, InterruptedException {
263     commonErrorTest(
264         CLIENT_PEM_FILE, null, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "key.pem");
265   }
266 
267   @Test
getCertificate_badKeyFile()268   public void getCertificate_badKeyFile() throws IOException, InterruptedException {
269     commonErrorTest(
270         CLIENT_PEM_FILE,
271         SERVER_0_PEM_FILE,
272         CA_PEM_FILE,
273         java.security.KeyException.class,
274         0,
275         1,
276         0,
277         0,
278         "could not find a PKCS #8 private key in input stream");
279   }
280 
281   @Test
getCertificate_missingRootFile()282   public void getCertificate_missingRootFile() throws IOException, InterruptedException {
283     TestScheduledFuture<?> scheduledFuture =
284         new TestScheduledFuture<>();
285     doReturn(scheduledFuture)
286         .when(timeService)
287         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
288     populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false);
289     provider.checkAndReloadCertificates();
290 
291     reset(mockWatcher);
292     timeProvider.forwardTime(1, TimeUnit.SECONDS);
293     populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, false, false, true);
294     timeProvider.forwardTime(
295         CERT0_EXPIRY_TIME_MILLIS - 610_000L - timeProvider.currentTimeMillis(),
296         TimeUnit.MILLISECONDS);
297     provider.checkAndReloadCertificates();
298     verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 1, 0, "root.pem");
299   }
300 
commonErrorTest( String certFile, String keyFile, String rootFile, Class<?> throwableType, int firstUpdateCertCount, int firstUpdateRootCount, int secondUpdateCertCount, int secondUpdateRootCount, String... causeMessages)301   private void commonErrorTest(
302       String certFile,
303       String keyFile,
304       String rootFile,
305       Class<?> throwableType,
306       int firstUpdateCertCount,
307       int firstUpdateRootCount,
308       int secondUpdateCertCount,
309       int secondUpdateRootCount,
310       String... causeMessages)
311       throws IOException, InterruptedException {
312     TestScheduledFuture<?> scheduledFuture =
313         new TestScheduledFuture<>();
314     doReturn(scheduledFuture)
315         .when(timeService)
316         .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
317     populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false);
318     provider.checkAndReloadCertificates();
319 
320     reset(mockWatcher);
321     timeProvider.forwardTime(1, TimeUnit.SECONDS);
322     populateTarget(
323         certFile, keyFile, rootFile, certFile == null, keyFile == null, rootFile == null);
324     timeProvider.forwardTime(
325         CERT0_EXPIRY_TIME_MILLIS - 610_000L - timeProvider.currentTimeMillis(),
326         TimeUnit.MILLISECONDS);
327     provider.checkAndReloadCertificates();
328     verifyWatcherErrorUpdates(
329         null, null, firstUpdateCertCount, firstUpdateRootCount, (String[]) null);
330 
331     reset(mockWatcher);
332     timeProvider.forwardTime(20, TimeUnit.SECONDS);
333     provider.checkAndReloadCertificates();
334     verifyWatcherErrorUpdates(
335         Status.Code.UNKNOWN,
336         throwableType,
337         secondUpdateCertCount,
338         secondUpdateRootCount,
339         causeMessages);
340   }
341 
verifyWatcherErrorUpdates( Status.Code code, Class<?> throwableType, int updateCertCount, int updateRootCount, String... causeMessages)342   private void verifyWatcherErrorUpdates(
343       Status.Code code,
344       Class<?> throwableType,
345       int updateCertCount,
346       int updateRootCount,
347       String... causeMessages) {
348     verify(mockWatcher, times(updateCertCount))
349         .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
350     verify(mockWatcher, times(updateRootCount))
351         .updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
352     if (code == null && throwableType == null && causeMessages == null) {
353       verify(mockWatcher, never()).onError(any(Status.class));
354     } else {
355       ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
356       verify(mockWatcher, times(1)).onError(statusCaptor.capture());
357       Status status = statusCaptor.getValue();
358       assertThat(status.getCode()).isEqualTo(code);
359       Throwable cause = status.getCause();
360       assertThat(cause).isInstanceOf(throwableType);
361       for (String causeMessage : causeMessages) {
362         assertThat(cause).hasMessageThat().contains(causeMessage);
363         cause = cause.getCause();
364       }
365     }
366   }
367 
verifyTimeServiceAndScheduledFuture()368   private void verifyTimeServiceAndScheduledFuture() {
369     verify(timeService, times(1)).schedule(any(Runnable.class), eq(600L), eq(TimeUnit.SECONDS));
370     assertThat(provider.scheduledFuture).isNotNull();
371     assertThat(provider.scheduledFuture.isDone()).isFalse();
372     assertThat(provider.scheduledFuture.isCancelled()).isFalse();
373   }
374 
verifyWatcherUpdates(String certPemFile, String rootPemFile)375   private void verifyWatcherUpdates(String certPemFile, String rootPemFile)
376       throws IOException, CertificateException {
377     if (certPemFile != null) {
378       @SuppressWarnings("unchecked")
379       ArgumentCaptor<List<X509Certificate>> certChainCaptor = ArgumentCaptor.forClass(List.class);
380       verify(mockWatcher, times(1))
381           .updateCertificate(any(PrivateKey.class), certChainCaptor.capture());
382       List<X509Certificate> certChain = certChainCaptor.getValue();
383       assertThat(certChain).hasSize(1);
384       assertThat(certChain.get(0))
385           .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(certPemFile));
386     } else {
387       verify(mockWatcher, never())
388           .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
389     }
390     if (rootPemFile != null) {
391       @SuppressWarnings("unchecked")
392       ArgumentCaptor<List<X509Certificate>> rootsCaptor = ArgumentCaptor.forClass(List.class);
393       verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture());
394       List<X509Certificate> roots = rootsCaptor.getValue();
395       assertThat(roots).hasSize(1);
396       assertThat(roots.get(0))
397           .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(rootPemFile));
398       verify(mockWatcher, never()).onError(any(Status.class));
399     } else {
400       verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
401     }
402   }
403 
404   static class TestScheduledFuture<V> implements ScheduledFuture<V> {
405 
406     static class Record {
407       long timeout;
408       TimeUnit unit;
409 
Record(long timeout, TimeUnit unit)410       Record(long timeout, TimeUnit unit) {
411         this.timeout = timeout;
412         this.unit = unit;
413       }
414     }
415 
416     ArrayList<Record> calls = new ArrayList<>();
417 
418     @Override
getDelay(TimeUnit unit)419     public long getDelay(TimeUnit unit) {
420       return 0;
421     }
422 
423     @Override
compareTo(Delayed o)424     public int compareTo(Delayed o) {
425       return 0;
426     }
427 
428     @Override
cancel(boolean mayInterruptIfRunning)429     public boolean cancel(boolean mayInterruptIfRunning) {
430       return false;
431     }
432 
433     @Override
isCancelled()434     public boolean isCancelled() {
435       return false;
436     }
437 
438     @Override
isDone()439     public boolean isDone() {
440       return false;
441     }
442 
443     @Override
get()444     public V get() {
445       return null;
446     }
447 
448     @Override
get(long timeout, TimeUnit unit)449     public V get(long timeout, TimeUnit unit) {
450       calls.add(new Record(timeout, unit));
451       return null;
452     }
453   }
454 
455   /**
456    * Fake TimeProvider that roughly mirrors FakeClock. Not using FakeClock because it incorrectly
457    * fails to align the wall-time API TimeProvider.currentTimeNanos() with currentTimeMillis() and
458    * fixing it upsets a _lot_ of tests.
459    */
460   static class FakeTimeProvider implements TimeProvider {
461     public long currentTimeNanos = TimeUnit.SECONDS.toNanos(1262332800); /* 2010-01-01 */
462 
currentTimeNanos()463     @Override public long currentTimeNanos() {
464       return currentTimeNanos;
465     }
466 
forwardTime(long duration, TimeUnit unit)467     public void forwardTime(long duration, TimeUnit unit) {
468       currentTimeNanos += unit.toNanos(duration);
469     }
470 
currentTimeMillis()471     public long currentTimeMillis() {
472       return TimeUnit.NANOSECONDS.toMillis(currentTimeNanos);
473     }
474   }
475 }
476