/*
 * Copyright 2020 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.xds.internal.security.certprovider;

import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_KEY_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static java.nio.file.StandardCopyOption.REPLACE_EXISTING;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import io.grpc.Status;
import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.security.CommonTlsContextTestsUtil;
import io.grpc.xds.internal.security.certprovider.CertificateProvider.DistributorWatcher;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.NoSuchFileException;
import java.nio.file.Paths;
import java.nio.file.attribute.FileTime;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Delayed;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

/** Unit tests for {@link FileWatcherCertificateProvider}. */
@RunWith(JUnit4.class)
public class FileWatcherCertificateProviderTest {
  /**
   * Expire time of cert SERVER_0_PEM_FILE.
   */
  static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L;

  /**
   * Value handed to the provider's constructor in {@link #setUp()} and expected as the delay (in
   * seconds) of every {@code timeService.schedule(...)} call.
   */
  private static final long SCHEDULE_DELAY_SECS = 600L;

  private static final String CERT_FILE = "cert.pem";
  private static final String KEY_FILE = "key.pem";
  private static final String ROOT_FILE = "root.pem";

  @Mock private CertificateProvider.Watcher mockWatcher;
  @Mock private ScheduledExecutorService timeService;
  private final FakeTimeProvider timeProvider = new FakeTimeProvider();

  @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
  @Rule public final MockitoRule mocks = MockitoJUnit.rule();

  /** Absolute paths of the watched files inside {@link #tempFolder}. */
  private String certFile;
  private String keyFile;
  private String rootFile;

  private FileWatcherCertificateProvider provider;

  /**
   * Creates the provider under test. It watches three (not yet existing) files in the temporary
   * folder and reports to {@link #mockWatcher} through a {@link DistributorWatcher}.
   */
  @Before
  public void setUp() throws IOException {
    DistributorWatcher watcher = new DistributorWatcher();
    watcher.addWatcher(mockWatcher);

    certFile = new File(tempFolder.getRoot(), CERT_FILE).getAbsolutePath();
    keyFile = new File(tempFolder.getRoot(), KEY_FILE).getAbsolutePath();
    rootFile = new File(tempFolder.getRoot(), ROOT_FILE).getAbsolutePath();
    provider =
        new FileWatcherCertificateProvider(
            watcher,
            true,
            certFile,
            keyFile,
            rootFile,
            SCHEDULE_DELAY_SECS,
            timeService,
            timeProvider);
  }

  /**
   * Stubs {@code timeService.schedule(Runnable, long, SECONDS)} to return the given future.
   *
   * <p>Must be called again after every {@code reset(timeService)} because a reset drops the
   * stubbing.
   */
  private void stubSchedule(ScheduledFuture<?> scheduledFuture) {
    doReturn(scheduledFuture)
        .when(timeService)
        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
  }

  /**
   * Updates the watched cert, key and root files, in that order.
   *
   * <p>For each file: when the matching {@code deleteCur*} flag is set the current file is deleted
   * first; when the matching {@code *FileSource} is non-null that test resource is copied over the
   * watched file and its last-modified time is set to the fake clock's current time.
   *
   * @param certFileSource resource name to install as the cert file, or {@code null} to skip
   * @param keyFileSource resource name to install as the key file, or {@code null} to skip
   * @param rootFileSource resource name to install as the root file, or {@code null} to skip
   * @param deleteCurCert whether to delete the current cert file first
   * @param deleteCurKey whether to delete the current key file first
   * @param deleteCurRoot whether to delete the current root file first
   * @throws IOException if deleting, copying or touching a file fails
   */
  private void populateTarget(
      String certFileSource,
      String keyFileSource,
      String rootFileSource,
      boolean deleteCurCert,
      boolean deleteCurKey,
      boolean deleteCurRoot)
      throws IOException {
    replaceTargetFile(certFileSource, certFile, deleteCurCert);
    replaceTargetFile(keyFileSource, keyFile, deleteCurKey);
    replaceTargetFile(rootFileSource, rootFile, deleteCurRoot);
  }

  /**
   * Optionally deletes {@code targetFile}, then (if {@code resourceName} is non-null) copies the
   * named test resource onto it and stamps it with the fake clock's current time.
   */
  private void replaceTargetFile(String resourceName, String targetFile, boolean deleteCurrent)
      throws IOException {
    if (deleteCurrent) {
      Files.delete(Paths.get(targetFile));
    }
    if (resourceName != null) {
      String sourceFile = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(resourceName);
      Files.copy(Paths.get(sourceFile), Paths.get(targetFile), REPLACE_EXISTING);
      Files.setLastModifiedTime(
          Paths.get(targetFile), FileTime.fromMillis(timeProvider.currentTimeMillis()));
    }
  }

  /**
   * The first reload delivers cert and root updates; a second reload with untouched files
   * delivers no update and no error. Both reloads schedule the next check.
   */
  @Test
  public void getCertificateAndCheckUpdates() throws IOException, CertificateException {
    TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
    stubSchedule(scheduledFuture);
    populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
    provider.checkAndReloadCertificates();
    verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE);
    verifyTimeServiceAndScheduledFuture();

    reset(mockWatcher, timeService);
    stubSchedule(scheduledFuture);
    provider.checkAndReloadCertificates();
    verifyWatcherErrorUpdates(null, null, 0, 0, (String[]) null);
    verifyTimeServiceAndScheduledFuture();
  }

  /** Replacing cert, key and root after the first load delivers both updates again. */
  @Test
  public void allUpdateSecondTime() throws IOException, CertificateException, InterruptedException {
    TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
    stubSchedule(scheduledFuture);
    populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
    provider.checkAndReloadCertificates();

    reset(mockWatcher, timeService);
    stubSchedule(scheduledFuture);
    timeProvider.forwardTime(1, TimeUnit.SECONDS);
    populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false);
    provider.checkAndReloadCertificates();
    verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE);
    verifyTimeServiceAndScheduledFuture();
  }

  /**
   * After {@code close()}, a reload neither notifies the watcher nor schedules another check, and
   * the executor is shut down exactly once.
   */
  @Test
  public void closeDoesNotScheduleNext() throws IOException, CertificateException {
    TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
    stubSchedule(scheduledFuture);
    populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
    provider.close();
    provider.checkAndReloadCertificates();
    verify(mockWatcher, never())
        .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
    verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
    verify(timeService, never()).schedule(any(Runnable.class), any(Long.TYPE), any(TimeUnit.class));
    verify(timeService, times(1)).shutdownNow();
  }

  /** Replacing only the root file delivers a trusted-roots update but no certificate update. */
  @Test
  public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException {
    TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
    stubSchedule(scheduledFuture);
    populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
    provider.checkAndReloadCertificates();

    reset(mockWatcher, timeService);
    stubSchedule(scheduledFuture);
    timeProvider.forwardTime(1, TimeUnit.SECONDS);
    populateTarget(null, null, SERVER_1_PEM_FILE, false, false, false);
    provider.checkAndReloadCertificates();
    verifyWatcherUpdates(null, SERVER_1_PEM_FILE);
    verifyTimeServiceAndScheduledFuture();
  }

  /** Replacing only cert and key delivers a certificate update but no trusted-roots update. */
  @Test
  public void certAndKeyFileUpdateOnly()
      throws IOException, CertificateException, InterruptedException {
    TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
    stubSchedule(scheduledFuture);
    populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
    provider.checkAndReloadCertificates();

    reset(mockWatcher, timeService);
    stubSchedule(scheduledFuture);
    timeProvider.forwardTime(1, TimeUnit.SECONDS);
    populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, null, false, false, false);
    provider.checkAndReloadCertificates();
    verifyWatcherUpdates(SERVER_0_PEM_FILE, null);
    verifyTimeServiceAndScheduledFuture();
  }

  /**
   * With no cert file on the very first load, the watcher gets the root update plus one
   * {@code UNKNOWN} error caused by a {@link NoSuchFileException} naming {@code cert.pem}.
   */
  @Test
  public void getCertificate_initialMissingCertFile() throws IOException {
    TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
    stubSchedule(scheduledFuture);
    populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, false, false, false);
    provider.checkAndReloadCertificates();
    verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 0, 1, "cert.pem");
  }

  /** Cert file deleted after a successful load; see {@link #commonErrorTest}. */
  @Test
  public void getCertificate_missingCertFile() throws IOException, InterruptedException {
    commonErrorTest(
        null, CLIENT_KEY_FILE, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "cert.pem");
  }

  /** Key file deleted after a successful load; see {@link #commonErrorTest}. */
  @Test
  public void getCertificate_missingKeyFile() throws IOException, InterruptedException {
    commonErrorTest(
        CLIENT_PEM_FILE, null, CA_PEM_FILE, NoSuchFileException.class, 0, 1, 0, 0, "key.pem");
  }

  /** Key file replaced by a certificate PEM (not a key); see {@link #commonErrorTest}. */
  @Test
  public void getCertificate_badKeyFile() throws IOException, InterruptedException {
    commonErrorTest(
        CLIENT_PEM_FILE,
        SERVER_0_PEM_FILE,
        CA_PEM_FILE,
        java.security.KeyException.class,
        0,
        1,
        0,
        0,
        "could not find a PKCS #8 private key in input stream");
  }

  /**
   * Root file deleted (and cert/key replaced) after a successful load: with the clock moved to
   * 610 seconds before {@link #CERT0_EXPIRY_TIME_MILLIS}, the reload delivers one certificate
   * update, no root update, and one {@code UNKNOWN} error caused by a
   * {@link NoSuchFileException} naming {@code root.pem}.
   */
  @Test
  public void getCertificate_missingRootFile() throws IOException, InterruptedException {
    TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
    stubSchedule(scheduledFuture);
    populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false);
    provider.checkAndReloadCertificates();

    reset(mockWatcher);
    timeProvider.forwardTime(1, TimeUnit.SECONDS);
    populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, null, false, false, true);
    timeProvider.forwardTime(
        CERT0_EXPIRY_TIME_MILLIS - 610_000L - timeProvider.currentTimeMillis(),
        TimeUnit.MILLISECONDS);
    provider.checkAndReloadCertificates();
    verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 1, 0, "root.pem");
  }

  /**
   * Shared scenario for errors that appear after an initial successful load of the SERVER_0
   * cert/key and SERVER_1 root.
   *
   * <ol>
   *   <li>Install the given sources; a {@code null} source means "delete that watched file".
   *   <li>Move the fake clock to 610 seconds before {@link #CERT0_EXPIRY_TIME_MILLIS}, reload, and
   *       expect the "first" update counts and <em>no</em> error.
   *   <li>Advance another 20 seconds, reload, and expect the "second" update counts and exactly
   *       one {@code UNKNOWN} error whose cause chain matches {@code throwableType} and
   *       {@code causeMessages}.
   * </ol>
   *
   * <p>NOTE(review): presumably the provider holds back errors until the loaded certificate is
   * close to expiry; confirm the exact threshold against FileWatcherCertificateProvider.
   *
   * @param certFileSource resource to install as cert file, or {@code null} to delete it
   * @param keyFileSource resource to install as key file, or {@code null} to delete it
   * @param rootFileSource resource to install as root file, or {@code null} to delete it
   * @param throwableType expected type of the error status' cause
   * @param firstUpdateCertCount expected certificate updates on the first reload
   * @param firstUpdateRootCount expected trusted-roots updates on the first reload
   * @param secondUpdateCertCount expected certificate updates on the second reload
   * @param secondUpdateRootCount expected trusted-roots updates on the second reload
   * @param causeMessages substrings expected in the messages of successive causes
   */
  private void commonErrorTest(
      String certFileSource,
      String keyFileSource,
      String rootFileSource,
      Class<?> throwableType,
      int firstUpdateCertCount,
      int firstUpdateRootCount,
      int secondUpdateCertCount,
      int secondUpdateRootCount,
      String... causeMessages)
      throws IOException, InterruptedException {
    TestScheduledFuture<?> scheduledFuture = new TestScheduledFuture<>();
    stubSchedule(scheduledFuture);
    populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, false, false);
    provider.checkAndReloadCertificates();

    reset(mockWatcher);
    timeProvider.forwardTime(1, TimeUnit.SECONDS);
    populateTarget(
        certFileSource,
        keyFileSource,
        rootFileSource,
        certFileSource == null,
        keyFileSource == null,
        rootFileSource == null);
    timeProvider.forwardTime(
        CERT0_EXPIRY_TIME_MILLIS - 610_000L - timeProvider.currentTimeMillis(),
        TimeUnit.MILLISECONDS);
    provider.checkAndReloadCertificates();
    verifyWatcherErrorUpdates(
        null, null, firstUpdateCertCount, firstUpdateRootCount, (String[]) null);

    reset(mockWatcher);
    timeProvider.forwardTime(20, TimeUnit.SECONDS);
    provider.checkAndReloadCertificates();
    verifyWatcherErrorUpdates(
        Status.Code.UNKNOWN,
        throwableType,
        secondUpdateCertCount,
        secondUpdateRootCount,
        causeMessages);
  }

  /**
   * Verifies the number of update callbacks and the (absence of an) error callback.
   *
   * <p>When {@code code}, {@code throwableType} and {@code causeMessages} are all {@code null},
   * {@code onError} must never have been called. Otherwise exactly one {@code onError} is
   * expected: its status code must equal {@code code}, its cause must be an instance of
   * {@code throwableType}, and the i-th entry of {@code causeMessages} must be contained in the
   * message of the i-th throwable of the cause chain.
   */
  private void verifyWatcherErrorUpdates(
      Status.Code code,
      Class<?> throwableType,
      int updateCertCount,
      int updateRootCount,
      String... causeMessages) {
    verify(mockWatcher, times(updateCertCount))
        .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
    verify(mockWatcher, times(updateRootCount))
        .updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
    if (code == null && throwableType == null && causeMessages == null) {
      verify(mockWatcher, never()).onError(any(Status.class));
    } else {
      ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
      verify(mockWatcher, times(1)).onError(statusCaptor.capture());
      Status status = statusCaptor.getValue();
      assertThat(status.getCode()).isEqualTo(code);
      Throwable cause = status.getCause();
      assertThat(cause).isInstanceOf(throwableType);
      for (String causeMessage : causeMessages) {
        assertThat(cause).hasMessageThat().contains(causeMessage);
        // Walk down the chain: the next expected message belongs to the nested cause.
        cause = cause.getCause();
      }
    }
  }

  /**
   * Verifies that exactly one check was scheduled with a delay of {@link #SCHEDULE_DELAY_SECS}
   * seconds and that the provider holds a future that is neither done nor cancelled.
   */
  private void verifyTimeServiceAndScheduledFuture() {
    verify(timeService, times(1))
        .schedule(any(Runnable.class), eq(SCHEDULE_DELAY_SECS), eq(TimeUnit.SECONDS));
    assertThat(provider.scheduledFuture).isNotNull();
    assertThat(provider.scheduledFuture.isDone()).isFalse();
    assertThat(provider.scheduledFuture.isCancelled()).isFalse();
  }

  /**
   * Verifies which update callbacks the watcher received.
   *
   * <p>A non-null {@code certPemFile} expects exactly one {@code updateCertificate} whose chain is
   * the single certificate loaded from that resource; {@code null} expects none. The same holds
   * for {@code rootPemFile} and {@code updateTrustedRoots}. The absence of {@code onError} is only
   * checked when {@code rootPemFile} is non-null.
   */
  private void verifyWatcherUpdates(String certPemFile, String rootPemFile)
      throws IOException, CertificateException {
    if (certPemFile != null) {
      @SuppressWarnings("unchecked")
      ArgumentCaptor<List<X509Certificate>> certChainCaptor = ArgumentCaptor.forClass(List.class);
      verify(mockWatcher, times(1))
          .updateCertificate(any(PrivateKey.class), certChainCaptor.capture());
      List<X509Certificate> certChain = certChainCaptor.getValue();
      assertThat(certChain).hasSize(1);
      assertThat(certChain.get(0))
          .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(certPemFile));
    } else {
      verify(mockWatcher, never())
          .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
    }
    if (rootPemFile != null) {
      @SuppressWarnings("unchecked")
      ArgumentCaptor<List<X509Certificate>> rootsCaptor = ArgumentCaptor.forClass(List.class);
      verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture());
      List<X509Certificate> roots = rootsCaptor.getValue();
      assertThat(roots).hasSize(1);
      assertThat(roots.get(0))
          .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(rootPemFile));
      verify(mockWatcher, never()).onError(any(Status.class));
    } else {
      verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
    }
  }

  /**
   * Inert {@link ScheduledFuture}: never done, never cancelled, cannot be cancelled, and yields
   * {@code null}. Calls to {@link #get(long, TimeUnit)} are recorded in {@link #calls}.
   */
  static class TestScheduledFuture<V> implements ScheduledFuture<V> {

    /** Arguments of one {@link TestScheduledFuture#get(long, TimeUnit)} invocation. */
    static class Record {
      long timeout;
      TimeUnit unit;

      Record(long timeout, TimeUnit unit) {
        this.timeout = timeout;
        this.unit = unit;
      }
    }

    ArrayList<Record> calls = new ArrayList<>();

    @Override
    public long getDelay(TimeUnit unit) {
      return 0;
    }

    @Override
    public int compareTo(Delayed o) {
      return 0;
    }

    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
      return false;
    }

    @Override
    public boolean isCancelled() {
      return false;
    }

    @Override
    public boolean isDone() {
      return false;
    }

    @Override
    public V get() {
      return null;
    }

    @Override
    public V get(long timeout, TimeUnit unit) {
      calls.add(new Record(timeout, unit));
      return null;
    }
  }

  /**
   * Fake TimeProvider that roughly mirrors FakeClock. Not using FakeClock because it incorrectly
   * fails to align the wall-time API TimeProvider.currentTimeNanos() with currentTimeMillis() and
   * fixing it upsets a _lot_ of tests.
   */
  static class FakeTimeProvider implements TimeProvider {
    public long currentTimeNanos = TimeUnit.SECONDS.toNanos(1262332800); /* 2010-01-01 */

    @Override
    public long currentTimeNanos() {
      return currentTimeNanos;
    }

    /** Moves the fake clock forward by the given duration. */
    public void forwardTime(long duration, TimeUnit unit) {
      currentTimeNanos += unit.toNanos(duration);
    }

    /** Returns the fake clock's current time converted from nanoseconds to milliseconds. */
    public long currentTimeMillis() {
      return TimeUnit.NANOSECONDS.toMillis(currentTimeNanos);
    }
  }
}