1 /*
2  * Copyright 2020 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.xds.internal.security.certprovider;
18 
19 import static com.google.common.base.Preconditions.checkNotNull;
20 
21 import com.google.common.annotations.VisibleForTesting;
22 import io.grpc.Status;
23 import io.grpc.internal.TimeProvider;
24 import io.grpc.xds.internal.security.trust.CertificateUtils;
25 import java.io.ByteArrayInputStream;
26 import java.nio.file.Files;
27 import java.nio.file.Path;
28 import java.nio.file.Paths;
29 import java.nio.file.attribute.FileTime;
30 import java.security.PrivateKey;
31 import java.security.cert.X509Certificate;
32 import java.util.Arrays;
33 import java.util.concurrent.ScheduledExecutorService;
34 import java.util.concurrent.ScheduledFuture;
35 import java.util.concurrent.TimeUnit;
36 import java.util.logging.Level;
37 import java.util.logging.Logger;
38 
// TODO(sanjaypujare): abstract out common functionality into an abstract superclass
40 /** Implementation of {@link CertificateProvider} for file watching cert provider. */
41 final class FileWatcherCertificateProvider extends CertificateProvider implements Runnable {
42   private static final Logger logger =
43       Logger.getLogger(FileWatcherCertificateProvider.class.getName());
44 
45   private final ScheduledExecutorService scheduledExecutorService;
46   private final TimeProvider timeProvider;
47   private final Path certFile;
48   private final Path keyFile;
49   private final Path trustFile;
50   private final long refreshIntervalInSeconds;
51   @VisibleForTesting ScheduledFuture<?> scheduledFuture;
52   private FileTime lastModifiedTimeCert;
53   private FileTime lastModifiedTimeKey;
54   private FileTime lastModifiedTimeRoot;
55   private boolean shutdown;
56 
FileWatcherCertificateProvider( DistributorWatcher watcher, boolean notifyCertUpdates, String certFile, String keyFile, String trustFile, long refreshIntervalInSeconds, ScheduledExecutorService scheduledExecutorService, TimeProvider timeProvider)57   FileWatcherCertificateProvider(
58       DistributorWatcher watcher,
59       boolean notifyCertUpdates,
60       String certFile,
61       String keyFile,
62       String trustFile,
63       long refreshIntervalInSeconds,
64       ScheduledExecutorService scheduledExecutorService,
65       TimeProvider timeProvider) {
66     super(watcher, notifyCertUpdates);
67     this.scheduledExecutorService =
68         checkNotNull(scheduledExecutorService, "scheduledExecutorService");
69     this.timeProvider = checkNotNull(timeProvider, "timeProvider");
70     this.certFile = Paths.get(checkNotNull(certFile, "certFile"));
71     this.keyFile = Paths.get(checkNotNull(keyFile, "keyFile"));
72     this.trustFile = Paths.get(checkNotNull(trustFile, "trustFile"));
73     this.refreshIntervalInSeconds = refreshIntervalInSeconds;
74   }
75 
76   @Override
start()77   public void start() {
78     scheduleNextRefreshCertificate(/* delayInSeconds= */0);
79   }
80 
81   @Override
close()82   public synchronized void close() {
83     shutdown = true;
84     scheduledExecutorService.shutdownNow();
85     if (scheduledFuture != null) {
86       scheduledFuture.cancel(true);
87       scheduledFuture = null;
88     }
89     getWatcher().close();
90   }
91 
scheduleNextRefreshCertificate(long delayInSeconds)92   private synchronized void scheduleNextRefreshCertificate(long delayInSeconds) {
93     if (!shutdown) {
94       scheduledFuture = scheduledExecutorService.schedule(this, delayInSeconds, TimeUnit.SECONDS);
95     }
96   }
97 
98   @VisibleForTesting
checkAndReloadCertificates()99   void checkAndReloadCertificates() {
100     try {
101       try {
102         FileTime currentCertTime = Files.getLastModifiedTime(certFile);
103         FileTime currentKeyTime = Files.getLastModifiedTime(keyFile);
104         if (!currentCertTime.equals(lastModifiedTimeCert)
105             && !currentKeyTime.equals(lastModifiedTimeKey)) {
106           byte[] certFileContents = Files.readAllBytes(certFile);
107           byte[] keyFileContents = Files.readAllBytes(keyFile);
108           FileTime currentCertTime2 = Files.getLastModifiedTime(certFile);
109           FileTime currentKeyTime2 = Files.getLastModifiedTime(keyFile);
110           if (!currentCertTime2.equals(currentCertTime)) {
111             return;
112           }
113           if (!currentKeyTime2.equals(currentKeyTime)) {
114             return;
115           }
116           try (ByteArrayInputStream certStream = new ByteArrayInputStream(certFileContents);
117               ByteArrayInputStream keyStream = new ByteArrayInputStream(keyFileContents)) {
118             PrivateKey privateKey = CertificateUtils.getPrivateKey(keyStream);
119             X509Certificate[] certs = CertificateUtils.toX509Certificates(certStream);
120             getWatcher().updateCertificate(privateKey, Arrays.asList(certs));
121           }
122           lastModifiedTimeCert = currentCertTime;
123           lastModifiedTimeKey = currentKeyTime;
124         }
125       } catch (Throwable t) {
126         generateErrorIfCurrentCertExpired(t);
127       }
128       try {
129         FileTime currentRootTime = Files.getLastModifiedTime(trustFile);
130         if (currentRootTime.equals(lastModifiedTimeRoot)) {
131           return;
132         }
133         byte[] rootFileContents = Files.readAllBytes(trustFile);
134         FileTime currentRootTime2 = Files.getLastModifiedTime(trustFile);
135         if (!currentRootTime2.equals(currentRootTime)) {
136           return;
137         }
138         try (ByteArrayInputStream rootStream = new ByteArrayInputStream(rootFileContents)) {
139           X509Certificate[] caCerts = CertificateUtils.toX509Certificates(rootStream);
140           getWatcher().updateTrustedRoots(Arrays.asList(caCerts));
141         }
142         lastModifiedTimeRoot = currentRootTime;
143       } catch (Throwable t) {
144         getWatcher().onError(Status.fromThrowable(t));
145       }
146     } finally {
147       scheduleNextRefreshCertificate(refreshIntervalInSeconds);
148     }
149   }
150 
generateErrorIfCurrentCertExpired(Throwable t)151   private void generateErrorIfCurrentCertExpired(Throwable t) {
152     X509Certificate currentCert = getWatcher().getLastIdentityCert();
153     if (currentCert != null) {
154       long delaySeconds = computeDelaySecondsToCertExpiry(currentCert);
155       if (delaySeconds > refreshIntervalInSeconds) {
156         logger.log(Level.FINER, "reload certificate error", t);
157         return;
158       }
159       // The current cert is going to expire in less than {@link refreshIntervalInSeconds}
160       // Clear the current cert and notify our watchers thru {@code onError}
161       getWatcher().clearValues();
162     }
163     getWatcher().onError(Status.fromThrowable(t));
164   }
165 
166   @SuppressWarnings("JdkObsolete")
computeDelaySecondsToCertExpiry(X509Certificate lastCert)167   private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) {
168     checkNotNull(lastCert, "lastCert");
169     return TimeUnit.NANOSECONDS.toSeconds(
170         TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime())
171             - timeProvider.currentTimeNanos());
172   }
173 
174   @Override
run()175   public void run() {
176     if (!shutdown) {
177       try {
178         checkAndReloadCertificates();
179       } catch (Throwable t) {
180         logger.log(Level.SEVERE, "Uncaught exception!", t);
181         if (t instanceof InterruptedException) {
182           Thread.currentThread().interrupt();
183         }
184       }
185     }
186   }
187 
188   abstract static class Factory {
189     private static final Factory DEFAULT_INSTANCE =
190         new Factory() {
191           @Override
192           FileWatcherCertificateProvider create(
193               DistributorWatcher watcher,
194               boolean notifyCertUpdates,
195               String certFile,
196               String keyFile,
197               String trustFile,
198               long refreshIntervalInSeconds,
199               ScheduledExecutorService scheduledExecutorService,
200               TimeProvider timeProvider) {
201             return new FileWatcherCertificateProvider(
202                 watcher,
203                 notifyCertUpdates,
204                 certFile,
205                 keyFile,
206                 trustFile,
207                 refreshIntervalInSeconds,
208                 scheduledExecutorService,
209                 timeProvider);
210           }
211         };
212 
getInstance()213     static Factory getInstance() {
214       return DEFAULT_INSTANCE;
215     }
216 
create( DistributorWatcher watcher, boolean notifyCertUpdates, String certFile, String keyFile, String trustFile, long refreshIntervalInSeconds, ScheduledExecutorService scheduledExecutorService, TimeProvider timeProvider)217     abstract FileWatcherCertificateProvider create(
218         DistributorWatcher watcher,
219         boolean notifyCertUpdates,
220         String certFile,
221         String keyFile,
222         String trustFile,
223         long refreshIntervalInSeconds,
224         ScheduledExecutorService scheduledExecutorService,
225         TimeProvider timeProvider);
226   }
227 }
228