1 /* 2 * Copyright 2020 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.xds.internal.security; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static java.nio.charset.StandardCharsets.UTF_8; 21 22 import com.google.common.io.CharStreams; 23 import com.google.common.util.concurrent.MoreExecutors; 24 import com.google.protobuf.BoolValue; 25 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; 26 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; 27 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; 28 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; 29 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; 30 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext; 31 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; 32 import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; 33 import io.grpc.internal.testing.TestUtils; 34 import io.grpc.testing.TlsTesting; 35 import io.grpc.xds.EnvoyServerProtoData; 36 import io.grpc.xds.internal.security.trust.CertificateUtils; 37 import io.netty.handler.ssl.SslContext; 38 import java.io.ByteArrayInputStream; 39 import java.io.IOException; 40 import java.io.InputStream; 41 import java.io.InputStreamReader; 42 import java.io.Reader; 43 import java.security.cert.CertificateException; 44 import java.security.cert.X509Certificate; 45 import java.util.Arrays; 46 import java.util.List; 47 import java.util.concurrent.Executor; 48 import javax.annotation.Nullable; 49 50 /** Utility class for client and server ssl provider tests. */ 51 public class CommonTlsContextTestsUtil { 52 53 public static final String SERVER_0_PEM_FILE = "server0.pem"; 54 public static final String SERVER_0_KEY_FILE = "server0.key"; 55 public static final String SERVER_1_PEM_FILE = "server1.pem"; 56 public static final String SERVER_1_KEY_FILE = "server1.key"; 57 public static final String CLIENT_PEM_FILE = "client.pem"; 58 public static final String CLIENT_KEY_FILE = "client.key"; 59 public static final String CA_PEM_FILE = "ca.pem"; 60 /** Bad/untrusted server certs. */ 61 public static final String BAD_SERVER_PEM_FILE = "badserver.pem"; 62 public static final String BAD_SERVER_KEY_FILE = "badserver.key"; 63 public static final String BAD_CLIENT_PEM_FILE = "badclient.pem"; 64 public static final String BAD_CLIENT_KEY_FILE = "badclient.key"; 65 66 /** takes additional values and creates CombinedCertificateValidationContext as needed. */ 67 @SuppressWarnings("deprecation") buildCommonTlsContextWithAdditionalValues( String certInstanceName, String certName, String validationContextCertInstanceName, String validationContextCertName, Iterable<StringMatcher> matchSubjectAltNames, Iterable<String> alpnNames)68 static CommonTlsContext buildCommonTlsContextWithAdditionalValues( 69 String certInstanceName, String certName, 70 String validationContextCertInstanceName, String validationContextCertName, 71 Iterable<StringMatcher> matchSubjectAltNames, 72 Iterable<String> alpnNames) { 73 74 CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); 75 76 CertificateProviderInstance certificateProviderInstance = CertificateProviderInstance 77 .newBuilder().setInstanceName(certInstanceName).setCertificateName(certName).build(); 78 if (certificateProviderInstance != null) { 79 builder.setTlsCertificateCertificateProviderInstance(certificateProviderInstance); 80 } 81 CertificateProviderInstance validationCertificateProviderInstance = 82 CertificateProviderInstance.newBuilder().setInstanceName(validationContextCertInstanceName) 83 .setCertificateName(validationContextCertName).build(); 84 CertificateValidationContext certValidationContext = 85 matchSubjectAltNames == null 86 ? null 87 : CertificateValidationContext.newBuilder() 88 .addAllMatchSubjectAltNames(matchSubjectAltNames) 89 .build(); 90 if (validationCertificateProviderInstance != null) { 91 CombinedCertificateValidationContext.Builder combinedBuilder = 92 CombinedCertificateValidationContext.newBuilder() 93 .setValidationContextCertificateProviderInstance( 94 validationCertificateProviderInstance); 95 if (certValidationContext != null) { 96 combinedBuilder = combinedBuilder.setDefaultValidationContext(certValidationContext); 97 } 98 builder.setCombinedValidationContext(combinedBuilder); 99 } else if (validationCertificateProviderInstance != null) { 100 builder 101 .setValidationContextCertificateProviderInstance(validationCertificateProviderInstance); 102 } else if (certValidationContext != null) { 103 builder.setValidationContext(certValidationContext); 104 } 105 if (alpnNames != null) { 106 builder.addAllAlpnProtocols(alpnNames); 107 } 108 return builder.build(); 109 } 110 111 /** Helper method to build DownstreamTlsContext for multiple test classes. */ buildDownstreamTlsContext( CommonTlsContext commonTlsContext, boolean requireClientCert)112 static DownstreamTlsContext buildDownstreamTlsContext( 113 CommonTlsContext commonTlsContext, boolean requireClientCert) { 114 DownstreamTlsContext.Builder downstreamTlsContextBuilder = 115 DownstreamTlsContext.newBuilder() 116 .setRequireClientCertificate(BoolValue.of(requireClientCert)); 117 if (commonTlsContext != null) { 118 downstreamTlsContextBuilder = downstreamTlsContextBuilder 119 .setCommonTlsContext(commonTlsContext); 120 } 121 return downstreamTlsContextBuilder.build(); 122 } 123 124 /** Helper method to build DownstreamTlsContext for multiple test classes. */ buildDownstreamTlsContext( String commonInstanceName, boolean hasRootCert, boolean requireClientCertificate)125 public static EnvoyServerProtoData.DownstreamTlsContext buildDownstreamTlsContext( 126 String commonInstanceName, boolean hasRootCert, 127 boolean requireClientCertificate) { 128 return buildDownstreamTlsContextForCertProviderInstance( 129 commonInstanceName, 130 "default", 131 hasRootCert ? commonInstanceName : null, 132 hasRootCert ? "ROOT" : null, 133 /* alpnProtocols= */ null, 134 /* staticCertValidationContext= */ null, 135 /* requireClientCert= */ requireClientCertificate); 136 } 137 138 /** Helper method to build internal DownstreamTlsContext for multiple test classes. */ buildInternalDownstreamTlsContext( CommonTlsContext commonTlsContext, boolean requireClientCert)139 static EnvoyServerProtoData.DownstreamTlsContext buildInternalDownstreamTlsContext( 140 CommonTlsContext commonTlsContext, boolean requireClientCert) { 141 return EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( 142 buildDownstreamTlsContext(commonTlsContext, requireClientCert)); 143 } 144 145 /** Helper method for creating DownstreamTlsContext values with names. */ buildTestDownstreamTlsContext( String certName, String validationContextCertName, boolean useSans)146 public static DownstreamTlsContext buildTestDownstreamTlsContext( 147 String certName, String validationContextCertName, boolean useSans) { 148 CommonTlsContext commonTlsContext = null; 149 if (certName != null || validationContextCertName != null || useSans) { 150 commonTlsContext = buildCommonTlsContextWithAdditionalValues( 151 "cert-instance-name", certName, 152 "cert-instance-name", validationContextCertName, 153 useSans ? Arrays.asList( 154 StringMatcher.newBuilder() 155 .setExact("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob") 156 .build()) : null, 157 Arrays.asList("managed-tls")); 158 } 159 return buildDownstreamTlsContext(commonTlsContext, /* requireClientCert= */ false); 160 } 161 buildTestInternalDownstreamTlsContext( String certName, String validationContextName)162 public static EnvoyServerProtoData.DownstreamTlsContext buildTestInternalDownstreamTlsContext( 163 String certName, String validationContextName) { 164 return EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( 165 buildTestDownstreamTlsContext(certName, validationContextName, true)); 166 } 167 getTempFileNameForResourcesFile(String resFile)168 public static String getTempFileNameForResourcesFile(String resFile) throws IOException { 169 return TestUtils.loadCert(resFile).getAbsolutePath(); 170 } 171 172 /** 173 * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. 174 */ buildUpstreamTlsContext( CommonTlsContext commonTlsContext)175 static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( 176 CommonTlsContext commonTlsContext) { 177 UpstreamTlsContext upstreamTlsContext = 178 UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build(); 179 return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext( 180 upstreamTlsContext); 181 } 182 183 /** Helper method to build UpstreamTlsContext for multiple test classes. */ buildUpstreamTlsContext( String commonInstanceName, boolean hasIdentityCert)184 public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( 185 String commonInstanceName, boolean hasIdentityCert) { 186 return buildUpstreamTlsContextForCertProviderInstance( 187 hasIdentityCert ? commonInstanceName : null, 188 hasIdentityCert ? "default" : null, 189 commonInstanceName, 190 "ROOT", 191 null, 192 null); 193 } 194 195 /** Gets a cert from contents of a resource. */ getCertFromResourceName(String resourceName)196 public static X509Certificate getCertFromResourceName(String resourceName) 197 throws IOException, CertificateException { 198 try (ByteArrayInputStream bais = 199 new ByteArrayInputStream(getResourceContents(resourceName).getBytes(UTF_8))) { 200 return CertificateUtils.toX509Certificate(bais); 201 } 202 } 203 204 /** Gets contents of a certs resource. */ getResourceContents(String resourceName)205 public static String getResourceContents(String resourceName) throws IOException { 206 InputStream inputStream = TlsTesting.loadCert(resourceName); 207 String text = null; 208 try (Reader reader = new InputStreamReader(inputStream, UTF_8)) { 209 text = CharStreams.toString(reader); 210 } 211 return text; 212 } 213 214 @SuppressWarnings("deprecation") buildCommonTlsContextForCertProviderInstance( String certInstanceName, String certName, String rootInstanceName, String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext)215 private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( 216 String certInstanceName, 217 String certName, 218 String rootInstanceName, 219 String rootCertName, 220 Iterable<String> alpnProtocols, 221 CertificateValidationContext staticCertValidationContext) { 222 CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); 223 if (certInstanceName != null) { 224 builder = 225 builder.setTlsCertificateCertificateProviderInstance( 226 CommonTlsContext.CertificateProviderInstance.newBuilder() 227 .setInstanceName(certInstanceName) 228 .setCertificateName(certName)); 229 } 230 builder = 231 addCertificateValidationContext( 232 builder, rootInstanceName, rootCertName, staticCertValidationContext); 233 if (alpnProtocols != null) { 234 builder.addAllAlpnProtocols(alpnProtocols); 235 } 236 return builder.build(); 237 } 238 buildNewCommonTlsContextForCertProviderInstance( String certInstanceName, String certName, String rootInstanceName, String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext)239 private static CommonTlsContext buildNewCommonTlsContextForCertProviderInstance( 240 String certInstanceName, 241 String certName, 242 String rootInstanceName, 243 String rootCertName, 244 Iterable<String> alpnProtocols, 245 CertificateValidationContext staticCertValidationContext) { 246 CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); 247 if (certInstanceName != null) { 248 builder = 249 builder.setTlsCertificateProviderInstance( 250 CertificateProviderPluginInstance.newBuilder() 251 .setInstanceName(certInstanceName) 252 .setCertificateName(certName)); 253 } 254 builder = 255 addNewCertificateValidationContext( 256 builder, rootInstanceName, rootCertName, staticCertValidationContext); 257 if (alpnProtocols != null) { 258 builder.addAllAlpnProtocols(alpnProtocols); 259 } 260 return builder.build(); 261 } 262 263 @SuppressWarnings("deprecation") addCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, String rootCertName, CertificateValidationContext staticCertValidationContext)264 private static CommonTlsContext.Builder addCertificateValidationContext( 265 CommonTlsContext.Builder builder, 266 String rootInstanceName, 267 String rootCertName, 268 CertificateValidationContext staticCertValidationContext) { 269 if (rootInstanceName != null) { 270 CertificateProviderInstance providerInstance = 271 CertificateProviderInstance.newBuilder() 272 .setInstanceName(rootInstanceName) 273 .setCertificateName(rootCertName) 274 .build(); 275 if (staticCertValidationContext != null) { 276 CombinedCertificateValidationContext combined = 277 CombinedCertificateValidationContext.newBuilder() 278 .setDefaultValidationContext(staticCertValidationContext) 279 .setValidationContextCertificateProviderInstance(providerInstance) 280 .build(); 281 return builder.setCombinedValidationContext(combined); 282 } 283 builder = builder.setValidationContextCertificateProviderInstance(providerInstance); 284 } 285 return builder; 286 } 287 addNewCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, String rootCertName, CertificateValidationContext staticCertValidationContext)288 private static CommonTlsContext.Builder addNewCertificateValidationContext( 289 CommonTlsContext.Builder builder, 290 String rootInstanceName, 291 String rootCertName, 292 CertificateValidationContext staticCertValidationContext) { 293 if (rootInstanceName != null) { 294 CertificateProviderPluginInstance providerInstance = 295 CertificateProviderPluginInstance.newBuilder() 296 .setInstanceName(rootInstanceName) 297 .setCertificateName(rootCertName) 298 .build(); 299 CertificateValidationContext.Builder validationContextBuilder = 300 staticCertValidationContext != null ? staticCertValidationContext.toBuilder() 301 : CertificateValidationContext.newBuilder(); 302 return builder.setValidationContext( 303 validationContextBuilder.setCaCertificateProviderInstance(providerInstance)); 304 } 305 return builder; 306 } 307 308 /** Helper method to build UpstreamTlsContext for CertProvider tests. */ 309 public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContextForCertProviderInstance( @ullable String certInstanceName, @Nullable String certName, @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext)310 buildUpstreamTlsContextForCertProviderInstance( 311 @Nullable String certInstanceName, 312 @Nullable String certName, 313 @Nullable String rootInstanceName, 314 @Nullable String rootCertName, 315 Iterable<String> alpnProtocols, 316 CertificateValidationContext staticCertValidationContext) { 317 return buildUpstreamTlsContext( 318 buildCommonTlsContextForCertProviderInstance( 319 certInstanceName, 320 certName, 321 rootInstanceName, 322 rootCertName, 323 alpnProtocols, 324 staticCertValidationContext)); 325 } 326 327 /** Helper method to build UpstreamTlsContext for CertProvider tests. */ 328 public static EnvoyServerProtoData.UpstreamTlsContext buildNewUpstreamTlsContextForCertProviderInstance( @ullable String certInstanceName, @Nullable String certName, @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext)329 buildNewUpstreamTlsContextForCertProviderInstance( 330 @Nullable String certInstanceName, 331 @Nullable String certName, 332 @Nullable String rootInstanceName, 333 @Nullable String rootCertName, 334 Iterable<String> alpnProtocols, 335 CertificateValidationContext staticCertValidationContext) { 336 return buildUpstreamTlsContext( 337 buildNewCommonTlsContextForCertProviderInstance( 338 certInstanceName, 339 certName, 340 rootInstanceName, 341 rootCertName, 342 alpnProtocols, 343 staticCertValidationContext)); 344 } 345 346 /** Helper method to build DownstreamTlsContext for CertProvider tests. */ 347 public static EnvoyServerProtoData.DownstreamTlsContext buildDownstreamTlsContextForCertProviderInstance( @ullable String certInstanceName, @Nullable String certName, @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext, boolean requireClientCert)348 buildDownstreamTlsContextForCertProviderInstance( 349 @Nullable String certInstanceName, 350 @Nullable String certName, 351 @Nullable String rootInstanceName, 352 @Nullable String rootCertName, 353 Iterable<String> alpnProtocols, 354 CertificateValidationContext staticCertValidationContext, 355 boolean requireClientCert) { 356 return buildInternalDownstreamTlsContext( 357 buildCommonTlsContextForCertProviderInstance( 358 certInstanceName, 359 certName, 360 rootInstanceName, 361 rootCertName, 362 alpnProtocols, 363 staticCertValidationContext), requireClientCert); 364 } 365 366 /** Helper method to build DownstreamTlsContext for CertProvider tests. */ 367 public static EnvoyServerProtoData.DownstreamTlsContext buildNewDownstreamTlsContextForCertProviderInstance( @ullable String certInstanceName, @Nullable String certName, @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable<String> alpnProtocols, CertificateValidationContext staticCertValidationContext, boolean requireClientCert)368 buildNewDownstreamTlsContextForCertProviderInstance( 369 @Nullable String certInstanceName, 370 @Nullable String certName, 371 @Nullable String rootInstanceName, 372 @Nullable String rootCertName, 373 Iterable<String> alpnProtocols, 374 CertificateValidationContext staticCertValidationContext, 375 boolean requireClientCert) { 376 return buildInternalDownstreamTlsContext( 377 buildNewCommonTlsContextForCertProviderInstance( 378 certInstanceName, 379 certName, 380 rootInstanceName, 381 rootCertName, 382 alpnProtocols, 383 staticCertValidationContext), requireClientCert); 384 } 385 386 /** Perform some simple checks on sslContext. */ doChecksOnSslContext(boolean server, SslContext sslContext, List<String> expectedApnProtos)387 public static void doChecksOnSslContext(boolean server, SslContext sslContext, 388 List<String> expectedApnProtos) { 389 if (server) { 390 assertThat(sslContext.isServer()).isTrue(); 391 } else { 392 assertThat(sslContext.isClient()).isTrue(); 393 } 394 List<String> apnProtos = sslContext.applicationProtocolNegotiator().protocols(); 395 assertThat(apnProtos).isNotNull(); 396 if (expectedApnProtos != null) { 397 assertThat(apnProtos).isEqualTo(expectedApnProtos); 398 } else { 399 assertThat(apnProtos).contains("h2"); 400 } 401 } 402 403 /** 404 * Helper method to get the value thru directExecutor callback. Because of directExecutor this is 405 * a synchronous callback - so need to provide a listener. 406 */ getValueThruCallback(SslContextProvider provider)407 public static TestCallback getValueThruCallback(SslContextProvider provider) { 408 return getValueThruCallback(provider, MoreExecutors.directExecutor()); 409 } 410 411 /** Helper method to get the value thru callback with a user passed executor. */ getValueThruCallback(SslContextProvider provider, Executor executor)412 public static TestCallback getValueThruCallback(SslContextProvider provider, Executor executor) { 413 TestCallback testCallback = new TestCallback(executor); 414 provider.addCallback(testCallback); 415 return testCallback; 416 } 417 418 public static class TestCallback extends SslContextProvider.Callback { 419 420 public SslContext updatedSslContext; 421 public Throwable updatedThrowable; 422 TestCallback(Executor executor)423 public TestCallback(Executor executor) { 424 super(executor); 425 } 426 427 @Override updateSslContext(SslContext sslContext)428 public void updateSslContext(SslContext sslContext) { 429 updatedSslContext = sslContext; 430 } 431 432 @Override onException(Throwable throwable)433 public void onException(Throwable throwable) { 434 updatedThrowable = throwable; 435 } 436 } 437 } 438