1 /*
2  * Copyright 2020 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.xds.internal.security;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static java.nio.charset.StandardCharsets.UTF_8;
21 
22 import com.google.common.io.CharStreams;
23 import com.google.common.util.concurrent.MoreExecutors;
24 import com.google.protobuf.BoolValue;
25 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance;
26 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
27 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
28 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance;
29 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
30 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext;
31 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext;
32 import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
33 import io.grpc.internal.testing.TestUtils;
34 import io.grpc.testing.TlsTesting;
35 import io.grpc.xds.EnvoyServerProtoData;
36 import io.grpc.xds.internal.security.trust.CertificateUtils;
37 import io.netty.handler.ssl.SslContext;
38 import java.io.ByteArrayInputStream;
39 import java.io.IOException;
40 import java.io.InputStream;
41 import java.io.InputStreamReader;
42 import java.io.Reader;
43 import java.security.cert.CertificateException;
44 import java.security.cert.X509Certificate;
45 import java.util.Arrays;
46 import java.util.List;
47 import java.util.concurrent.Executor;
48 import javax.annotation.Nullable;
49 
50 /** Utility class for client and server ssl provider tests. */
51 public class CommonTlsContextTestsUtil {
52 
53   public static final String SERVER_0_PEM_FILE = "server0.pem";
54   public static final String SERVER_0_KEY_FILE = "server0.key";
55   public static final String SERVER_1_PEM_FILE = "server1.pem";
56   public static final String SERVER_1_KEY_FILE = "server1.key";
57   public static final String CLIENT_PEM_FILE = "client.pem";
58   public static final String CLIENT_KEY_FILE = "client.key";
59   public static final String CA_PEM_FILE = "ca.pem";
60   /** Bad/untrusted server certs. */
61   public static final String BAD_SERVER_PEM_FILE = "badserver.pem";
62   public static final String BAD_SERVER_KEY_FILE = "badserver.key";
63   public static final String BAD_CLIENT_PEM_FILE = "badclient.pem";
64   public static final String BAD_CLIENT_KEY_FILE = "badclient.key";
65 
66   /** takes additional values and creates CombinedCertificateValidationContext as needed. */
67   @SuppressWarnings("deprecation")
buildCommonTlsContextWithAdditionalValues( String certInstanceName, String certName, String validationContextCertInstanceName, String validationContextCertName, Iterable<StringMatcher> matchSubjectAltNames, Iterable<String> alpnNames)68   static CommonTlsContext buildCommonTlsContextWithAdditionalValues(
69       String certInstanceName, String certName,
70       String validationContextCertInstanceName, String validationContextCertName,
71       Iterable<StringMatcher> matchSubjectAltNames,
72       Iterable<String> alpnNames) {
73 
74     CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
75 
76     CertificateProviderInstance certificateProviderInstance = CertificateProviderInstance
77         .newBuilder().setInstanceName(certInstanceName).setCertificateName(certName).build();
78     if (certificateProviderInstance != null) {
79       builder.setTlsCertificateCertificateProviderInstance(certificateProviderInstance);
80     }
81     CertificateProviderInstance validationCertificateProviderInstance =
82         CertificateProviderInstance.newBuilder().setInstanceName(validationContextCertInstanceName)
83             .setCertificateName(validationContextCertName).build();
84     CertificateValidationContext certValidationContext =
85         matchSubjectAltNames == null
86             ? null
87             : CertificateValidationContext.newBuilder()
88                 .addAllMatchSubjectAltNames(matchSubjectAltNames)
89                 .build();
90     if (validationCertificateProviderInstance != null) {
91       CombinedCertificateValidationContext.Builder combinedBuilder =
92           CombinedCertificateValidationContext.newBuilder()
93               .setValidationContextCertificateProviderInstance(
94                   validationCertificateProviderInstance);
95       if (certValidationContext != null) {
96         combinedBuilder = combinedBuilder.setDefaultValidationContext(certValidationContext);
97       }
98       builder.setCombinedValidationContext(combinedBuilder);
99     } else if (validationCertificateProviderInstance != null) {
100       builder
101           .setValidationContextCertificateProviderInstance(validationCertificateProviderInstance);
102     } else if (certValidationContext != null) {
103       builder.setValidationContext(certValidationContext);
104     }
105     if (alpnNames != null) {
106       builder.addAllAlpnProtocols(alpnNames);
107     }
108     return builder.build();
109   }
110 
111   /** Helper method to build DownstreamTlsContext for multiple test classes. */
buildDownstreamTlsContext( CommonTlsContext commonTlsContext, boolean requireClientCert)112   static DownstreamTlsContext buildDownstreamTlsContext(
113       CommonTlsContext commonTlsContext, boolean requireClientCert) {
114     DownstreamTlsContext.Builder downstreamTlsContextBuilder =
115         DownstreamTlsContext.newBuilder()
116             .setRequireClientCertificate(BoolValue.of(requireClientCert));
117     if (commonTlsContext != null) {
118       downstreamTlsContextBuilder = downstreamTlsContextBuilder
119           .setCommonTlsContext(commonTlsContext);
120     }
121     return downstreamTlsContextBuilder.build();
122   }
123 
124   /** Helper method to build DownstreamTlsContext for multiple test classes. */
buildDownstreamTlsContext( String commonInstanceName, boolean hasRootCert, boolean requireClientCertificate)125   public static EnvoyServerProtoData.DownstreamTlsContext buildDownstreamTlsContext(
126       String commonInstanceName, boolean hasRootCert,
127       boolean requireClientCertificate) {
128     return buildDownstreamTlsContextForCertProviderInstance(
129         commonInstanceName,
130         "default",
131         hasRootCert ? commonInstanceName : null,
132         hasRootCert ? "ROOT" : null,
133         /* alpnProtocols= */ null,
134         /* staticCertValidationContext= */ null,
135         /* requireClientCert= */ requireClientCertificate);
136   }
137 
138   /** Helper method to build internal DownstreamTlsContext for multiple test classes. */
buildInternalDownstreamTlsContext( CommonTlsContext commonTlsContext, boolean requireClientCert)139   static EnvoyServerProtoData.DownstreamTlsContext buildInternalDownstreamTlsContext(
140       CommonTlsContext commonTlsContext, boolean requireClientCert) {
141     return EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext(
142         buildDownstreamTlsContext(commonTlsContext, requireClientCert));
143   }
144 
145   /** Helper method for creating DownstreamTlsContext values with names. */
buildTestDownstreamTlsContext( String certName, String validationContextCertName, boolean useSans)146   public static DownstreamTlsContext buildTestDownstreamTlsContext(
147       String certName, String validationContextCertName, boolean useSans) {
148     CommonTlsContext commonTlsContext = null;
149     if (certName != null || validationContextCertName != null || useSans) {
150       commonTlsContext = buildCommonTlsContextWithAdditionalValues(
151           "cert-instance-name", certName,
152           "cert-instance-name", validationContextCertName,
153           useSans ? Arrays.asList(
154               StringMatcher.newBuilder()
155                   .setExact("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob")
156                   .build()) : null,
157           Arrays.asList("managed-tls"));
158     }
159     return buildDownstreamTlsContext(commonTlsContext, /* requireClientCert= */ false);
160   }
161 
buildTestInternalDownstreamTlsContext( String certName, String validationContextName)162   public static EnvoyServerProtoData.DownstreamTlsContext buildTestInternalDownstreamTlsContext(
163       String certName, String validationContextName) {
164     return EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext(
165         buildTestDownstreamTlsContext(certName, validationContextName, true));
166   }
167 
getTempFileNameForResourcesFile(String resFile)168   public static String getTempFileNameForResourcesFile(String resFile) throws IOException {
169     return TestUtils.loadCert(resFile).getAbsolutePath();
170   }
171 
172   /**
173    * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well.
174    */
buildUpstreamTlsContext( CommonTlsContext commonTlsContext)175   static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext(
176       CommonTlsContext commonTlsContext) {
177     UpstreamTlsContext upstreamTlsContext =
178         UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build();
179     return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext(
180         upstreamTlsContext);
181   }
182 
183   /** Helper method to build UpstreamTlsContext for multiple test classes. */
buildUpstreamTlsContext( String commonInstanceName, boolean hasIdentityCert)184   public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext(
185       String commonInstanceName, boolean hasIdentityCert) {
186     return buildUpstreamTlsContextForCertProviderInstance(
187         hasIdentityCert ? commonInstanceName : null,
188         hasIdentityCert ? "default" : null,
189         commonInstanceName,
190         "ROOT",
191         null,
192         null);
193   }
194 
195   /** Gets a cert from contents of a resource. */
getCertFromResourceName(String resourceName)196   public static X509Certificate getCertFromResourceName(String resourceName)
197       throws IOException, CertificateException {
198     try (ByteArrayInputStream bais =
199         new ByteArrayInputStream(getResourceContents(resourceName).getBytes(UTF_8))) {
200       return CertificateUtils.toX509Certificate(bais);
201     }
202   }
203 
204   /** Gets contents of a certs resource. */
getResourceContents(String resourceName)205   public static String getResourceContents(String resourceName) throws IOException {
206     InputStream inputStream = TlsTesting.loadCert(resourceName);
207     String text = null;
208     try (Reader reader = new InputStreamReader(inputStream, UTF_8)) {
209       text = CharStreams.toString(reader);
210     }
211     return text;
212   }
213 
214   @SuppressWarnings("deprecation")
buildCommonTlsContextForCertProviderInstance( String certInstanceName, String certName, String rootInstanceName, String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext)215   private static CommonTlsContext buildCommonTlsContextForCertProviderInstance(
216       String certInstanceName,
217       String certName,
218       String rootInstanceName,
219       String rootCertName,
220       Iterable<String> alpnProtocols,
221       CertificateValidationContext staticCertValidationContext) {
222     CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
223     if (certInstanceName != null) {
224       builder =
225           builder.setTlsCertificateCertificateProviderInstance(
226               CommonTlsContext.CertificateProviderInstance.newBuilder()
227                   .setInstanceName(certInstanceName)
228                   .setCertificateName(certName));
229     }
230     builder =
231         addCertificateValidationContext(
232             builder, rootInstanceName, rootCertName, staticCertValidationContext);
233     if (alpnProtocols != null) {
234       builder.addAllAlpnProtocols(alpnProtocols);
235     }
236     return builder.build();
237   }
238 
buildNewCommonTlsContextForCertProviderInstance( String certInstanceName, String certName, String rootInstanceName, String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext)239   private static CommonTlsContext buildNewCommonTlsContextForCertProviderInstance(
240           String certInstanceName,
241           String certName,
242           String rootInstanceName,
243           String rootCertName,
244           Iterable<String> alpnProtocols,
245           CertificateValidationContext staticCertValidationContext) {
246     CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
247     if (certInstanceName != null) {
248       builder =
249               builder.setTlsCertificateProviderInstance(
250                       CertificateProviderPluginInstance.newBuilder()
251                               .setInstanceName(certInstanceName)
252                               .setCertificateName(certName));
253     }
254     builder =
255             addNewCertificateValidationContext(
256                     builder, rootInstanceName, rootCertName, staticCertValidationContext);
257     if (alpnProtocols != null) {
258       builder.addAllAlpnProtocols(alpnProtocols);
259     }
260     return builder.build();
261   }
262 
263   @SuppressWarnings("deprecation")
addCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, String rootCertName, CertificateValidationContext staticCertValidationContext)264   private static CommonTlsContext.Builder addCertificateValidationContext(
265       CommonTlsContext.Builder builder,
266       String rootInstanceName,
267       String rootCertName,
268       CertificateValidationContext staticCertValidationContext) {
269     if (rootInstanceName != null) {
270       CertificateProviderInstance providerInstance =
271           CertificateProviderInstance.newBuilder()
272               .setInstanceName(rootInstanceName)
273               .setCertificateName(rootCertName)
274               .build();
275       if (staticCertValidationContext != null) {
276         CombinedCertificateValidationContext combined =
277             CombinedCertificateValidationContext.newBuilder()
278                 .setDefaultValidationContext(staticCertValidationContext)
279                 .setValidationContextCertificateProviderInstance(providerInstance)
280                 .build();
281         return builder.setCombinedValidationContext(combined);
282       }
283       builder = builder.setValidationContextCertificateProviderInstance(providerInstance);
284     }
285     return builder;
286   }
287 
addNewCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, String rootCertName, CertificateValidationContext staticCertValidationContext)288   private static CommonTlsContext.Builder addNewCertificateValidationContext(
289           CommonTlsContext.Builder builder,
290           String rootInstanceName,
291           String rootCertName,
292           CertificateValidationContext staticCertValidationContext) {
293     if (rootInstanceName != null) {
294       CertificateProviderPluginInstance providerInstance =
295           CertificateProviderPluginInstance.newBuilder()
296               .setInstanceName(rootInstanceName)
297               .setCertificateName(rootCertName)
298               .build();
299       CertificateValidationContext.Builder validationContextBuilder =
300           staticCertValidationContext != null ? staticCertValidationContext.toBuilder()
301               : CertificateValidationContext.newBuilder();
302       return builder.setValidationContext(
303           validationContextBuilder.setCaCertificateProviderInstance(providerInstance));
304     }
305     return builder;
306   }
307 
308   /** Helper method to build UpstreamTlsContext for CertProvider tests. */
309   public static EnvoyServerProtoData.UpstreamTlsContext
buildUpstreamTlsContextForCertProviderInstance( @ullable String certInstanceName, @Nullable String certName, @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext)310       buildUpstreamTlsContextForCertProviderInstance(
311           @Nullable String certInstanceName,
312           @Nullable String certName,
313           @Nullable String rootInstanceName,
314           @Nullable String rootCertName,
315           Iterable<String> alpnProtocols,
316           CertificateValidationContext staticCertValidationContext) {
317     return buildUpstreamTlsContext(
318         buildCommonTlsContextForCertProviderInstance(
319             certInstanceName,
320             certName,
321             rootInstanceName,
322             rootCertName,
323             alpnProtocols,
324             staticCertValidationContext));
325   }
326 
327   /** Helper method to build UpstreamTlsContext for CertProvider tests. */
328   public static EnvoyServerProtoData.UpstreamTlsContext
buildNewUpstreamTlsContextForCertProviderInstance( @ullable String certInstanceName, @Nullable String certName, @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext)329       buildNewUpstreamTlsContextForCertProviderInstance(
330           @Nullable String certInstanceName,
331           @Nullable String certName,
332           @Nullable String rootInstanceName,
333           @Nullable String rootCertName,
334           Iterable<String> alpnProtocols,
335           CertificateValidationContext staticCertValidationContext) {
336     return buildUpstreamTlsContext(
337         buildNewCommonTlsContextForCertProviderInstance(
338             certInstanceName,
339             certName,
340             rootInstanceName,
341             rootCertName,
342             alpnProtocols,
343             staticCertValidationContext));
344   }
345 
346   /** Helper method to build DownstreamTlsContext for CertProvider tests. */
347   public static EnvoyServerProtoData.DownstreamTlsContext
buildDownstreamTlsContextForCertProviderInstance( @ullable String certInstanceName, @Nullable String certName, @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext, boolean requireClientCert)348       buildDownstreamTlsContextForCertProviderInstance(
349           @Nullable String certInstanceName,
350           @Nullable String certName,
351           @Nullable String rootInstanceName,
352           @Nullable String rootCertName,
353           Iterable<String> alpnProtocols,
354           CertificateValidationContext staticCertValidationContext,
355           boolean requireClientCert) {
356     return buildInternalDownstreamTlsContext(
357         buildCommonTlsContextForCertProviderInstance(
358             certInstanceName,
359             certName,
360             rootInstanceName,
361             rootCertName,
362             alpnProtocols,
363             staticCertValidationContext), requireClientCert);
364   }
365 
366   /** Helper method to build DownstreamTlsContext for CertProvider tests. */
367   public static EnvoyServerProtoData.DownstreamTlsContext
buildNewDownstreamTlsContextForCertProviderInstance( @ullable String certInstanceName, @Nullable String certName, @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext, boolean requireClientCert)368       buildNewDownstreamTlsContextForCertProviderInstance(
369           @Nullable String certInstanceName,
370           @Nullable String certName,
371           @Nullable String rootInstanceName,
372           @Nullable String rootCertName,
373           Iterable<String> alpnProtocols,
374           CertificateValidationContext staticCertValidationContext,
375           boolean requireClientCert) {
376     return buildInternalDownstreamTlsContext(
377             buildNewCommonTlsContextForCertProviderInstance(
378                     certInstanceName,
379                     certName,
380                     rootInstanceName,
381                     rootCertName,
382                     alpnProtocols,
383                     staticCertValidationContext), requireClientCert);
384   }
385 
386   /** Perform some simple checks on sslContext. */
doChecksOnSslContext(boolean server, SslContext sslContext, List<String> expectedApnProtos)387   public static void doChecksOnSslContext(boolean server, SslContext sslContext,
388       List<String> expectedApnProtos) {
389     if (server) {
390       assertThat(sslContext.isServer()).isTrue();
391     } else {
392       assertThat(sslContext.isClient()).isTrue();
393     }
394     List<String> apnProtos = sslContext.applicationProtocolNegotiator().protocols();
395     assertThat(apnProtos).isNotNull();
396     if (expectedApnProtos != null) {
397       assertThat(apnProtos).isEqualTo(expectedApnProtos);
398     } else {
399       assertThat(apnProtos).contains("h2");
400     }
401   }
402 
403   /**
404    * Helper method to get the value thru directExecutor callback. Because of directExecutor this is
405    * a synchronous callback - so need to provide a listener.
406    */
getValueThruCallback(SslContextProvider provider)407   public static TestCallback getValueThruCallback(SslContextProvider provider) {
408     return getValueThruCallback(provider, MoreExecutors.directExecutor());
409   }
410 
411   /** Helper method to get the value thru callback with a user passed executor. */
getValueThruCallback(SslContextProvider provider, Executor executor)412   public static TestCallback getValueThruCallback(SslContextProvider provider, Executor executor) {
413     TestCallback testCallback = new TestCallback(executor);
414     provider.addCallback(testCallback);
415     return testCallback;
416   }
417 
418   public static class TestCallback extends SslContextProvider.Callback {
419 
420     public SslContext updatedSslContext;
421     public Throwable updatedThrowable;
422 
TestCallback(Executor executor)423     public TestCallback(Executor executor) {
424       super(executor);
425     }
426 
427     @Override
updateSslContext(SslContext sslContext)428     public void updateSslContext(SslContext sslContext) {
429       updatedSslContext = sslContext;
430     }
431 
432     @Override
onException(Throwable throwable)433     public void onException(Throwable throwable) {
434       updatedThrowable = throwable;
435     }
436   }
437 }
438