xref: /aosp_15_r20/external/grpc-grpc/src/core/tsi/ssl_transport_security.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/tsi/ssl_transport_security.h"
22 
23 #include <limits.h>
24 #include <string.h>
25 
26 // TODO(jboeuf): refactor inet_ntop into a portability header.
27 // Note: for whomever reads this and tries to refactor this, this
28 // can't be in grpc, it has to be in gpr.
29 #ifdef GPR_WINDOWS
30 #include <ws2tcpip.h>
31 #else
32 #include <arpa/inet.h>
33 #include <sys/socket.h>
34 #endif
35 
36 #include <memory>
37 #include <string>
38 
39 #include <openssl/bio.h>
40 #include <openssl/crypto.h>  // For OPENSSL_free
41 #include <openssl/engine.h>
42 #include <openssl/err.h>
43 #include <openssl/ssl.h>
44 #include <openssl/tls1.h>
45 #include <openssl/x509.h>
46 #include <openssl/x509v3.h>
47 
48 #include "absl/strings/match.h"
49 #include "absl/strings/str_cat.h"
50 #include "absl/strings/string_view.h"
51 
52 #include <grpc/grpc_crl_provider.h>
53 #include <grpc/grpc_security.h>
54 #include <grpc/support/alloc.h>
55 #include <grpc/support/log.h>
56 #include <grpc/support/string_util.h>
57 #include <grpc/support/sync.h>
58 #include <grpc/support/thd_id.h>
59 
60 #include "src/core/lib/gpr/useful.h"
61 #include "src/core/lib/gprpp/crash.h"
62 #include "src/core/lib/security/credentials/tls/grpc_tls_crl_provider.h"
63 #include "src/core/tsi/ssl/key_logging/ssl_key_logging.h"
64 #include "src/core/tsi/ssl/session_cache/ssl_session_cache.h"
65 #include "src/core/tsi/ssl_transport_security_utils.h"
66 #include "src/core/tsi/ssl_types.h"
67 #include "src/core/tsi/transport_security.h"
68 
// --- Constants. ---

// NOTE(review): the uses of the following limits are outside this chunk; the
// descriptions below are inferred from their names and should be confirmed.
// Presumably the maximum number of attempts to write into a BIO.
#define TSI_SSL_MAX_BIO_WRITE_ATTEMPTS 100
// Presumably upper/lower bounds (in bytes) for the negotiated protected frame
// size.
#define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND 16384
#define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND 1024
// Presumably the initial size of tsi_ssl_handshaker::outgoing_bytes_buffer.
#define TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE 1024
// Presumably the maximum certificate chain length accepted; not referenced in
// this chunk.
const size_t kMaxChainLength = 100;

// Putting a macro like this and littering the source file with #if is really
// bad practice.
// TODO(jboeuf): refactor all the #if / #endif in a separate module.
// ALPN support defaults to enabled (1) unless the build already defines it.
#ifndef TSI_OPENSSL_ALPN_SUPPORT
#define TSI_OPENSSL_ALPN_SUPPORT 1
#endif

// TODO(jboeuf): I have not found a way to get this number dynamically from the
// SSL structure. This is what we would ultimately want though...
#define TSI_SSL_MAX_PROTECTION_OVERHEAD 100

// Shorthand for the key-logger type held by the handshaker factories below.
using TlsSessionKeyLogger = tsi::TlsSessionKeyLoggerCache::TlsSessionKeyLogger;
89 
// --- Structure definitions. ---

// Wraps an OpenSSL X509_STORE holding trusted root certificates.
struct tsi_ssl_root_certs_store {
  X509_STORE* store;
};

// Common header embedded as `base` in the client and server factories below:
// a vtable plus a reference count.
struct tsi_ssl_handshaker_factory {
  const tsi_ssl_handshaker_factory_vtable* vtable;
  gpr_refcount refcount;
};

// Client-side factory: a single SSL_CTX, the ALPN protocol list (pointer plus
// length in bytes), an SSL session cache and an optional TLS key logger.
// NOTE(review): `base` is the first member, presumably so a pointer to this
// struct can be used as a tsi_ssl_handshaker_factory* — confirm at call sites.
struct tsi_ssl_client_handshaker_factory {
  tsi_ssl_handshaker_factory base;
  SSL_CTX* ssl_context;
  unsigned char* alpn_protocol_list;
  size_t alpn_protocol_list_length;
  grpc_core::RefCountedPtr<tsi::SslSessionLRUCache> session_cache;
  grpc_core::RefCountedPtr<TlsSessionKeyLogger> key_logger;
};

struct tsi_ssl_server_handshaker_factory {
  // Several contexts to support SNI.
  // The tsi_peer array contains the subject names of the server certificates
  // associated with the contexts at the same index.
  // Both arrays hold ssl_context_count entries.
  tsi_ssl_handshaker_factory base;
  SSL_CTX** ssl_contexts;
  tsi_peer* ssl_context_x509_subject_names;
  size_t ssl_context_count;
  unsigned char* alpn_protocol_list;
  size_t alpn_protocol_list_length;
  grpc_core::RefCountedPtr<TlsSessionKeyLogger> key_logger;
};

// Per-connection handshaker state.
// NOTE(review): `network_io` is presumably the BIO through which handshake
// bytes are exchanged with the transport, and `factory_ref` a counted
// reference to the creating factory — confirm in the handshaker code.
struct tsi_ssl_handshaker {
  tsi_handshaker base;
  SSL* ssl;
  BIO* network_io;
  tsi_result result;
  unsigned char* outgoing_bytes_buffer;
  size_t outgoing_bytes_buffer_size;
  tsi_ssl_handshaker_factory* factory_ref;
};
// Result of a completed handshake: the SSL object and BIO, plus bytes received
// beyond the end of the handshake (`unused_bytes`, `unused_bytes_size` long).
struct tsi_ssl_handshaker_result {
  tsi_handshaker_result base;
  SSL* ssl;
  BIO* network_io;
  unsigned char* unused_bytes;
  size_t unused_bytes_size;
};
// Frame protector built on an SSL object and BIO, with a working buffer of
// `buffer_size` bytes and a current fill position `buffer_offset`.
struct tsi_ssl_frame_protector {
  tsi_frame_protector base;
  SSL* ssl;
  BIO* network_io;
  unsigned char* buffer;
  size_t buffer_size;
  size_t buffer_offset;
};
// --- Library Initialization. ---

// Guard for one-time OpenSSL initialization.
static gpr_once g_init_openssl_once = GPR_ONCE_INIT;
// Ex-data indexes allocated by init_openssl(); -1 until it has run.
static int g_ssl_ctx_ex_factory_index = -1;
static int g_ssl_ctx_ex_crl_provider_index = -1;
// Session id context bytes ("grpc").
static const unsigned char kSslSessionIdContext[] = {'g', 'r', 'p', 'c'};
// SSL-level ex-data slot whose X509 value is released by
// verified_root_cert_free().
static int g_ssl_ex_verified_root_cert_index = -1;
#if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
// Private keys starting with this prefix are loaded through an OpenSSL ENGINE,
// using the format "engine:<engine_id>:<key_id>".
static const char kSslEnginePrefix[] = "engine:";
#endif
#if OPENSSL_VERSION_NUMBER >= 0x30000000
static const int kSslEcCurveNames[] = {NID_X9_62_prime256v1};
#endif
160 
161 #if OPENSSL_VERSION_NUMBER < 0x10100000
162 static gpr_mu* g_openssl_mutexes = nullptr;
163 static void openssl_locking_cb(int mode, int type, const char* file,
164                                int line) GRPC_UNUSED;
165 static unsigned long openssl_thread_id_cb(void) GRPC_UNUSED;
166 
openssl_locking_cb(int mode,int type,const char * file,int line)167 static void openssl_locking_cb(int mode, int type, const char* file, int line) {
168   if (mode & CRYPTO_LOCK) {
169     gpr_mu_lock(&g_openssl_mutexes[type]);
170   } else {
171     gpr_mu_unlock(&g_openssl_mutexes[type]);
172   }
173 }
174 
openssl_thread_id_cb(void)175 static unsigned long openssl_thread_id_cb(void) {
176   return static_cast<unsigned long>(gpr_thd_currentid());
177 }
178 #endif
179 
verified_root_cert_free(void *,void * ptr,CRYPTO_EX_DATA *,int,long,void *)180 static void verified_root_cert_free(void* /*parent*/, void* ptr,
181                                     CRYPTO_EX_DATA* /*ad*/, int /*index*/,
182                                     long /*argl*/, void* /*argp*/) {
183   X509_free(static_cast<X509*>(ptr));
184 }
185 
init_openssl(void)186 static void init_openssl(void) {
187 #if OPENSSL_VERSION_NUMBER >= 0x10100000
188   OPENSSL_init_ssl(0, nullptr);
189 #else
190   SSL_library_init();
191   SSL_load_error_strings();
192   OpenSSL_add_all_algorithms();
193 #endif
194 #if OPENSSL_VERSION_NUMBER < 0x10100000
195   if (!CRYPTO_get_locking_callback()) {
196     int num_locks = CRYPTO_num_locks();
197     GPR_ASSERT(num_locks > 0);
198     g_openssl_mutexes = static_cast<gpr_mu*>(
199         gpr_malloc(static_cast<size_t>(num_locks) * sizeof(gpr_mu)));
200     for (int i = 0; i < num_locks; i++) {
201       gpr_mu_init(&g_openssl_mutexes[i]);
202     }
203     CRYPTO_set_locking_callback(openssl_locking_cb);
204     CRYPTO_set_id_callback(openssl_thread_id_cb);
205   } else {
206     gpr_log(GPR_INFO, "OpenSSL callback has already been set.");
207   }
208 #endif
209   g_ssl_ctx_ex_factory_index =
210       SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
211   GPR_ASSERT(g_ssl_ctx_ex_factory_index != -1);
212 
213   g_ssl_ctx_ex_crl_provider_index =
214       SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
215   GPR_ASSERT(g_ssl_ctx_ex_crl_provider_index != -1);
216 
217   g_ssl_ex_verified_root_cert_index = SSL_get_ex_new_index(
218       0, nullptr, nullptr, nullptr, verified_root_cert_free);
219   GPR_ASSERT(g_ssl_ex_verified_root_cert_index != -1);
220 }
221 
222 // --- Ssl utils. ---
223 
224 // TODO(jboeuf): Remove when we are past the debugging phase with this code.
ssl_log_where_info(const SSL * ssl,int where,int flag,const char * msg)225 static void ssl_log_where_info(const SSL* ssl, int where, int flag,
226                                const char* msg) {
227   if ((where & flag) && GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) {
228     gpr_log(GPR_INFO, "%20.20s - %30.30s  - %5.10s", msg,
229             SSL_state_string_long(ssl), SSL_state_string(ssl));
230   }
231 }
232 
233 // Used for debugging. TODO(jboeuf): Remove when code is mature enough.
ssl_info_callback(const SSL * ssl,int where,int ret)234 static void ssl_info_callback(const SSL* ssl, int where, int ret) {
235   if (ret == 0) {
236     gpr_log(GPR_ERROR, "ssl_info_callback: error occurred.\n");
237     return;
238   }
239 
240   ssl_log_where_info(ssl, where, SSL_CB_LOOP, "LOOP");
241   ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_START, "HANDSHAKE START");
242   ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_DONE, "HANDSHAKE DONE");
243 }
244 
245 // Returns 1 if name looks like an IP address, 0 otherwise.
246 // This is a very rough heuristic, and only handles IPv6 in hexadecimal form.
looks_like_ip_address(absl::string_view name)247 static int looks_like_ip_address(absl::string_view name) {
248   size_t dot_count = 0;
249   size_t num_size = 0;
250   for (size_t i = 0; i < name.size(); ++i) {
251     if (name[i] == ':') {
252       // IPv6 Address in hexadecimal form, : is not allowed in DNS names.
253       return 1;
254     }
255     if (name[i] >= '0' && name[i] <= '9') {
256       if (num_size > 3) return 0;
257       num_size++;
258     } else if (name[i] == '.') {
259       if (dot_count > 3 || num_size == 0) return 0;
260       dot_count++;
261       num_size = 0;
262     } else {
263       return 0;
264     }
265   }
266   if (dot_count < 3 || num_size == 0) return 0;
267   return 1;
268 }
269 
270 // Gets the subject CN from an X509 cert.
ssl_get_x509_common_name(X509 * cert,unsigned char ** utf8,size_t * utf8_size)271 static tsi_result ssl_get_x509_common_name(X509* cert, unsigned char** utf8,
272                                            size_t* utf8_size) {
273   int common_name_index = -1;
274   X509_NAME_ENTRY* common_name_entry = nullptr;
275   ASN1_STRING* common_name_asn1 = nullptr;
276   X509_NAME* subject_name = X509_get_subject_name(cert);
277   int utf8_returned_size = 0;
278   if (subject_name == nullptr) {
279     gpr_log(GPR_DEBUG, "Could not get subject name from certificate.");
280     return TSI_NOT_FOUND;
281   }
282   common_name_index =
283       X509_NAME_get_index_by_NID(subject_name, NID_commonName, -1);
284   if (common_name_index == -1) {
285     gpr_log(GPR_DEBUG,
286             "Could not get common name of subject from certificate.");
287     return TSI_NOT_FOUND;
288   }
289   common_name_entry = X509_NAME_get_entry(subject_name, common_name_index);
290   if (common_name_entry == nullptr) {
291     gpr_log(GPR_ERROR, "Could not get common name entry from certificate.");
292     return TSI_INTERNAL_ERROR;
293   }
294   common_name_asn1 = X509_NAME_ENTRY_get_data(common_name_entry);
295   if (common_name_asn1 == nullptr) {
296     gpr_log(GPR_ERROR,
297             "Could not get common name entry asn1 from certificate.");
298     return TSI_INTERNAL_ERROR;
299   }
300   utf8_returned_size = ASN1_STRING_to_UTF8(utf8, common_name_asn1);
301   if (utf8_returned_size < 0) {
302     gpr_log(GPR_ERROR, "Could not extract utf8 from asn1 string.");
303     return TSI_OUT_OF_RESOURCES;
304   }
305   *utf8_size = static_cast<size_t>(utf8_returned_size);
306   return TSI_OK;
307 }
308 
309 // Gets the subject CN of an X509 cert as a tsi_peer_property.
peer_property_from_x509_common_name(X509 * cert,tsi_peer_property * property)310 static tsi_result peer_property_from_x509_common_name(
311     X509* cert, tsi_peer_property* property) {
312   unsigned char* common_name;
313   size_t common_name_size;
314   tsi_result result =
315       ssl_get_x509_common_name(cert, &common_name, &common_name_size);
316   if (result != TSI_OK) {
317     if (result == TSI_NOT_FOUND) {
318       common_name = nullptr;
319       common_name_size = 0;
320     } else {
321       return result;
322     }
323   }
324   result = tsi_construct_string_peer_property(
325       TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY,
326       common_name == nullptr ? "" : reinterpret_cast<const char*>(common_name),
327       common_name_size, property);
328   OPENSSL_free(common_name);
329   return result;
330 }
331 
332 // Gets the subject of an X509 cert as a tsi_peer_property.
peer_property_from_x509_subject(X509 * cert,tsi_peer_property * property,bool is_verified_root_cert)333 static tsi_result peer_property_from_x509_subject(X509* cert,
334                                                   tsi_peer_property* property,
335                                                   bool is_verified_root_cert) {
336   X509_NAME* subject_name = X509_get_subject_name(cert);
337   if (subject_name == nullptr) {
338     gpr_log(GPR_INFO, "Could not get subject name from certificate.");
339     return TSI_NOT_FOUND;
340   }
341   BIO* bio = BIO_new(BIO_s_mem());
342   X509_NAME_print_ex(bio, subject_name, 0, XN_FLAG_RFC2253);
343   char* contents;
344   long len = BIO_get_mem_data(bio, &contents);
345   if (len < 0) {
346     gpr_log(GPR_ERROR, "Could not get subject entry from certificate.");
347     BIO_free(bio);
348     return TSI_INTERNAL_ERROR;
349   }
350   tsi_result result;
351   if (!is_verified_root_cert) {
352     result = tsi_construct_string_peer_property(
353         TSI_X509_SUBJECT_PEER_PROPERTY, contents, static_cast<size_t>(len),
354         property);
355   } else {
356     result = tsi_construct_string_peer_property(
357         TSI_X509_VERIFIED_ROOT_CERT_SUBECT_PEER_PROPERTY, contents,
358         static_cast<size_t>(len), property);
359   }
360   BIO_free(bio);
361   return result;
362 }
363 
364 // Gets the X509 cert in PEM format as a tsi_peer_property.
add_pem_certificate(X509 * cert,tsi_peer_property * property)365 static tsi_result add_pem_certificate(X509* cert, tsi_peer_property* property) {
366   BIO* bio = BIO_new(BIO_s_mem());
367   if (!PEM_write_bio_X509(bio, cert)) {
368     BIO_free(bio);
369     return TSI_INTERNAL_ERROR;
370   }
371   char* contents;
372   long len = BIO_get_mem_data(bio, &contents);
373   if (len <= 0) {
374     BIO_free(bio);
375     return TSI_INTERNAL_ERROR;
376   }
377   tsi_result result = tsi_construct_string_peer_property(
378       TSI_X509_PEM_CERT_PROPERTY, contents, static_cast<size_t>(len), property);
379   BIO_free(bio);
380   return result;
381 }
382 
// Gets the subject SANs from an X509 cert as a tsi_peer_property.
// Appends properties for the first `subject_alt_name_count` entries of
// `subject_alt_names` to `peer->properties`. Writing starts at slot
// `*current_insert_index`, and the index advances by one for each slot used:
//   - DNS / email / URI SANs use two slots. The first is a generic
//     TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY; the second is the
//     type-specific DNS, email or URI property. Both hold the UTF-8 value.
//   - IP SANs use two slots: the generic SAN property and
//     TSI_X509_IP_PEER_PROPERTY. Both hold the inet_ntop() text form. Only
//     4- or 16-byte addresses are accepted.
//   - Any other SAN type uses one generic slot holding "other types of SAN".
// The caller must have sized `peer->properties` for all these slots.
// Processing stops at the first failure and returns its status. Because of
// the post-increments, the index has already moved past a slot whose
// construction failed.
static tsi_result add_subject_alt_names_properties_to_peer(
    tsi_peer* peer, GENERAL_NAMES* subject_alt_names,
    size_t subject_alt_name_count, int* current_insert_index) {
  size_t i;
  tsi_result result = TSI_OK;

  for (i = 0; i < subject_alt_name_count; i++) {
    GENERAL_NAME* subject_alt_name =
        sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i));
    if (subject_alt_name->type == GEN_DNS ||
        subject_alt_name->type == GEN_EMAIL ||
        subject_alt_name->type == GEN_URI) {
      // `name` is allocated by ASN1_STRING_to_UTF8 and released with
      // OPENSSL_free once the properties have been built from it.
      unsigned char* name = nullptr;
      int name_size;
      std::string property_name;
      if (subject_alt_name->type == GEN_DNS) {
        name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.dNSName);
        property_name = TSI_X509_DNS_PEER_PROPERTY;
      } else if (subject_alt_name->type == GEN_EMAIL) {
        name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.rfc822Name);
        property_name = TSI_X509_EMAIL_PEER_PROPERTY;
      } else {
        name_size = ASN1_STRING_to_UTF8(
            &name, subject_alt_name->d.uniformResourceIdentifier);
        property_name = TSI_X509_URI_PEER_PROPERTY;
      }
      if (name_size < 0) {
        gpr_log(GPR_ERROR, "Could not get utf8 from asn1 string.");
        result = TSI_INTERNAL_ERROR;
        break;
      }
      // First slot: the generic SAN property.
      result = tsi_construct_string_peer_property(
          TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY,
          reinterpret_cast<const char*>(name), static_cast<size_t>(name_size),
          &peer->properties[(*current_insert_index)++]);
      if (result != TSI_OK) {
        OPENSSL_free(name);
        break;
      }
      // Second slot: the type-specific (DNS / email / URI) property.
      result = tsi_construct_string_peer_property(
          property_name.c_str(), reinterpret_cast<const char*>(name),
          static_cast<size_t>(name_size),
          &peer->properties[(*current_insert_index)++]);
      OPENSSL_free(name);
    } else if (subject_alt_name->type == GEN_IPADD) {
      char ntop_buf[INET6_ADDRSTRLEN];
      int af;

      // The raw octet length selects the address family: 4 bytes is IPv4,
      // 16 bytes is IPv6; anything else is rejected.
      if (subject_alt_name->d.iPAddress->length == 4) {
        af = AF_INET;
      } else if (subject_alt_name->d.iPAddress->length == 16) {
        af = AF_INET6;
      } else {
        gpr_log(GPR_ERROR, "SAN IP Address contained invalid IP");
        result = TSI_INTERNAL_ERROR;
        break;
      }
      // `name` points into the stack buffer `ntop_buf`.
      const char* name = inet_ntop(af, subject_alt_name->d.iPAddress->data,
                                   ntop_buf, INET6_ADDRSTRLEN);
      if (name == nullptr) {
        gpr_log(GPR_ERROR, "Could not get IP string from asn1 octet.");
        result = TSI_INTERNAL_ERROR;
        break;
      }

      result = tsi_construct_string_peer_property_from_cstring(
          TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, name,
          &peer->properties[(*current_insert_index)++]);
      if (result != TSI_OK) break;
      result = tsi_construct_string_peer_property_from_cstring(
          TSI_X509_IP_PEER_PROPERTY, name,
          &peer->properties[(*current_insert_index)++]);
    } else {
      // Unsupported SAN types get a single placeholder property.
      result = tsi_construct_string_peer_property_from_cstring(
          TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, "other types of SAN",
          &peer->properties[(*current_insert_index)++]);
    }
    if (result != TSI_OK) break;
  }
  return result;
}
465 
466 // Gets information about the peer's X509 cert as a tsi_peer object.
peer_from_x509(X509 * cert,int include_certificate_type,tsi_peer * peer)467 static tsi_result peer_from_x509(X509* cert, int include_certificate_type,
468                                  tsi_peer* peer) {
469   // TODO(jboeuf): Maybe add more properties.
470   GENERAL_NAMES* subject_alt_names = static_cast<GENERAL_NAMES*>(
471       X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
472   int subject_alt_name_count =
473       (subject_alt_names != nullptr)
474           ? static_cast<int>(sk_GENERAL_NAME_num(subject_alt_names))
475           : 0;
476   size_t property_count;
477   tsi_result result;
478   GPR_ASSERT(subject_alt_name_count >= 0);
479   property_count = (include_certificate_type ? size_t{1} : 0) +
480                    3 /* subject, common name, certificate */ +
481                    static_cast<size_t>(subject_alt_name_count);
482   for (int i = 0; i < subject_alt_name_count; i++) {
483     GENERAL_NAME* subject_alt_name =
484         sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i));
485     // TODO(zhenlian): Clean up tsi_peer to avoid duplicate entries.
486     // URI, DNS, email and ip address SAN fields are plumbed to tsi_peer, in
487     // addition to all SAN fields (results in duplicate values). This code
488     // snippet updates property_count accordingly.
489     if (subject_alt_name->type == GEN_URI ||
490         subject_alt_name->type == GEN_DNS ||
491         subject_alt_name->type == GEN_EMAIL ||
492         subject_alt_name->type == GEN_IPADD) {
493       property_count += 1;
494     }
495   }
496   result = tsi_construct_peer(property_count, peer);
497   if (result != TSI_OK) return result;
498   int current_insert_index = 0;
499   do {
500     if (include_certificate_type) {
501       result = tsi_construct_string_peer_property_from_cstring(
502           TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE,
503           &peer->properties[current_insert_index++]);
504       if (result != TSI_OK) break;
505     }
506 
507     result = peer_property_from_x509_subject(
508         cert, &peer->properties[current_insert_index++],
509         /*is_verified_root_cert=*/false);
510     if (result != TSI_OK) break;
511 
512     result = peer_property_from_x509_common_name(
513         cert, &peer->properties[current_insert_index++]);
514     if (result != TSI_OK) break;
515 
516     result =
517         add_pem_certificate(cert, &peer->properties[current_insert_index++]);
518     if (result != TSI_OK) break;
519 
520     if (subject_alt_name_count != 0) {
521       result = add_subject_alt_names_properties_to_peer(
522           peer, subject_alt_names, static_cast<size_t>(subject_alt_name_count),
523           &current_insert_index);
524       if (result != TSI_OK) break;
525     }
526   } while (false);
527 
528   if (subject_alt_names != nullptr) {
529     sk_GENERAL_NAME_pop_free(subject_alt_names, GENERAL_NAME_free);
530   }
531   if (result != TSI_OK) tsi_peer_destruct(peer);
532 
533   GPR_ASSERT((int)peer->property_count == current_insert_index);
534   return result;
535 }
536 
537 // Loads an in-memory PEM certificate chain into the SSL context.
ssl_ctx_use_certificate_chain(SSL_CTX * context,const char * pem_cert_chain,size_t pem_cert_chain_size)538 static tsi_result ssl_ctx_use_certificate_chain(SSL_CTX* context,
539                                                 const char* pem_cert_chain,
540                                                 size_t pem_cert_chain_size) {
541   tsi_result result = TSI_OK;
542   X509* certificate = nullptr;
543   BIO* pem;
544   GPR_ASSERT(pem_cert_chain_size <= INT_MAX);
545   pem = BIO_new_mem_buf(pem_cert_chain, static_cast<int>(pem_cert_chain_size));
546   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
547 
548   do {
549     certificate =
550         PEM_read_bio_X509_AUX(pem, nullptr, nullptr, const_cast<char*>(""));
551     if (certificate == nullptr) {
552       result = TSI_INVALID_ARGUMENT;
553       break;
554     }
555     if (!SSL_CTX_use_certificate(context, certificate)) {
556       result = TSI_INVALID_ARGUMENT;
557       break;
558     }
559     while (true) {
560       X509* certificate_authority =
561           PEM_read_bio_X509(pem, nullptr, nullptr, const_cast<char*>(""));
562       if (certificate_authority == nullptr) {
563         ERR_clear_error();
564         break;  // Done reading.
565       }
566       if (!SSL_CTX_add_extra_chain_cert(context, certificate_authority)) {
567         X509_free(certificate_authority);
568         result = TSI_INVALID_ARGUMENT;
569         break;
570       }
571       // We don't need to free certificate_authority as its ownership has been
572       // transferred to the context. That is not the case for certificate
573       // though.
574       //
575     }
576   } while (false);
577 
578   if (certificate != nullptr) X509_free(certificate);
579   BIO_free(pem);
580   return result;
581 }
582 
583 #if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
ssl_ctx_use_engine_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)584 static tsi_result ssl_ctx_use_engine_private_key(SSL_CTX* context,
585                                                  const char* pem_key,
586                                                  size_t pem_key_size) {
587   tsi_result result = TSI_OK;
588   EVP_PKEY* private_key = nullptr;
589   ENGINE* engine = nullptr;
590   char* engine_name = nullptr;
591   // Parse key which is in following format engine:<engine_id>:<key_id>
592   do {
593     char* engine_start = (char*)pem_key + strlen(kSslEnginePrefix);
594     char* engine_end = (char*)strchr(engine_start, ':');
595     if (engine_end == nullptr) {
596       result = TSI_INVALID_ARGUMENT;
597       break;
598     }
599     char* key_id = engine_end + 1;
600     int engine_name_length = engine_end - engine_start;
601     if (engine_name_length == 0) {
602       result = TSI_INVALID_ARGUMENT;
603       break;
604     }
605     engine_name = static_cast<char*>(gpr_zalloc(engine_name_length + 1));
606     memcpy(engine_name, engine_start, engine_name_length);
607     gpr_log(GPR_DEBUG, "ENGINE key: %s", engine_name);
608     ENGINE_load_dynamic();
609     engine = ENGINE_by_id(engine_name);
610     if (engine == nullptr) {
611       // If not available at ENGINE_DIR, use dynamic to load from
612       // current working directory.
613       engine = ENGINE_by_id("dynamic");
614       if (engine == nullptr) {
615         gpr_log(GPR_ERROR, "Cannot load dynamic engine");
616         result = TSI_INVALID_ARGUMENT;
617         break;
618       }
619       if (!ENGINE_ctrl_cmd_string(engine, "ID", engine_name, 0) ||
620           !ENGINE_ctrl_cmd_string(engine, "DIR_LOAD", "2", 0) ||
621           !ENGINE_ctrl_cmd_string(engine, "DIR_ADD", ".", 0) ||
622           !ENGINE_ctrl_cmd_string(engine, "LIST_ADD", "1", 0) ||
623           !ENGINE_ctrl_cmd_string(engine, "LOAD", NULL, 0)) {
624         gpr_log(GPR_ERROR, "Cannot find engine");
625         result = TSI_INVALID_ARGUMENT;
626         break;
627       }
628     }
629     if (!ENGINE_set_default(engine, ENGINE_METHOD_ALL)) {
630       gpr_log(GPR_ERROR, "ENGINE_set_default with ENGINE_METHOD_ALL failed");
631       result = TSI_INVALID_ARGUMENT;
632       break;
633     }
634     if (!ENGINE_init(engine)) {
635       gpr_log(GPR_ERROR, "ENGINE_init failed");
636       result = TSI_INVALID_ARGUMENT;
637       break;
638     }
639     private_key = ENGINE_load_private_key(engine, key_id, 0, 0);
640     if (private_key == nullptr) {
641       gpr_log(GPR_ERROR, "ENGINE_load_private_key failed");
642       result = TSI_INVALID_ARGUMENT;
643       break;
644     }
645     if (!SSL_CTX_use_PrivateKey(context, private_key)) {
646       gpr_log(GPR_ERROR, "SSL_CTX_use_PrivateKey failed");
647       result = TSI_INVALID_ARGUMENT;
648       break;
649     }
650   } while (0);
651   if (engine != nullptr) ENGINE_free(engine);
652   if (private_key != nullptr) EVP_PKEY_free(private_key);
653   if (engine_name != nullptr) gpr_free(engine_name);
654   return result;
655 }
656 #endif  // !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
657 
ssl_ctx_use_pem_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)658 static tsi_result ssl_ctx_use_pem_private_key(SSL_CTX* context,
659                                               const char* pem_key,
660                                               size_t pem_key_size) {
661   tsi_result result = TSI_OK;
662   EVP_PKEY* private_key = nullptr;
663   BIO* pem;
664   GPR_ASSERT(pem_key_size <= INT_MAX);
665   pem = BIO_new_mem_buf(pem_key, static_cast<int>(pem_key_size));
666   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
667   do {
668     private_key =
669         PEM_read_bio_PrivateKey(pem, nullptr, nullptr, const_cast<char*>(""));
670     if (private_key == nullptr) {
671       result = TSI_INVALID_ARGUMENT;
672       break;
673     }
674     if (!SSL_CTX_use_PrivateKey(context, private_key)) {
675       result = TSI_INVALID_ARGUMENT;
676       break;
677     }
678   } while (false);
679   if (private_key != nullptr) EVP_PKEY_free(private_key);
680   BIO_free(pem);
681   return result;
682 }
683 
684 // Loads an in-memory PEM private key into the SSL context.
ssl_ctx_use_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)685 static tsi_result ssl_ctx_use_private_key(SSL_CTX* context, const char* pem_key,
686                                           size_t pem_key_size) {
687 // BoringSSL does not have ENGINE support
688 #if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
689   if (strncmp(pem_key, kSslEnginePrefix, strlen(kSslEnginePrefix)) == 0) {
690     return ssl_ctx_use_engine_private_key(context, pem_key, pem_key_size);
691   } else
692 #endif  // !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
693   {
694     return ssl_ctx_use_pem_private_key(context, pem_key, pem_key_size);
695   }
696 }
697 
698 // Loads in-memory PEM verification certs into the SSL context and optionally
699 // returns the verification cert names (root_names can be NULL).
x509_store_load_certs(X509_STORE * cert_store,const char * pem_roots,size_t pem_roots_size,STACK_OF (X509_NAME)** root_names)700 static tsi_result x509_store_load_certs(X509_STORE* cert_store,
701                                         const char* pem_roots,
702                                         size_t pem_roots_size,
703                                         STACK_OF(X509_NAME) * *root_names) {
704   tsi_result result = TSI_OK;
705   size_t num_roots = 0;
706   X509* root = nullptr;
707   X509_NAME* root_name = nullptr;
708   BIO* pem;
709   GPR_ASSERT(pem_roots_size <= INT_MAX);
710   pem = BIO_new_mem_buf(pem_roots, static_cast<int>(pem_roots_size));
711   if (cert_store == nullptr) return TSI_INVALID_ARGUMENT;
712   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
713   if (root_names != nullptr) {
714     *root_names = sk_X509_NAME_new_null();
715     if (*root_names == nullptr) return TSI_OUT_OF_RESOURCES;
716   }
717 
718   while (true) {
719     root = PEM_read_bio_X509_AUX(pem, nullptr, nullptr, const_cast<char*>(""));
720     if (root == nullptr) {
721       ERR_clear_error();
722       break;  // We're at the end of stream.
723     }
724     if (root_names != nullptr) {
725       root_name = X509_get_subject_name(root);
726       if (root_name == nullptr) {
727         gpr_log(GPR_ERROR, "Could not get name from root certificate.");
728         result = TSI_INVALID_ARGUMENT;
729         break;
730       }
731       root_name = X509_NAME_dup(root_name);
732       if (root_name == nullptr) {
733         result = TSI_OUT_OF_RESOURCES;
734         break;
735       }
736       sk_X509_NAME_push(*root_names, root_name);
737       root_name = nullptr;
738     }
739     ERR_clear_error();
740     if (!X509_STORE_add_cert(cert_store, root)) {
741       unsigned long error = ERR_get_error();
742       if (ERR_GET_LIB(error) != ERR_LIB_X509 ||
743           ERR_GET_REASON(error) != X509_R_CERT_ALREADY_IN_HASH_TABLE) {
744         gpr_log(GPR_ERROR, "Could not add root certificate to ssl context.");
745         result = TSI_INTERNAL_ERROR;
746         break;
747       }
748     }
749     X509_free(root);
750     num_roots++;
751   }
752   if (num_roots == 0) {
753     gpr_log(GPR_ERROR, "Could not load any root certificate.");
754     result = TSI_INVALID_ARGUMENT;
755   }
756 
757   if (result != TSI_OK) {
758     if (root != nullptr) X509_free(root);
759     if (root_names != nullptr) {
760       sk_X509_NAME_pop_free(*root_names, X509_NAME_free);
761       *root_names = nullptr;
762       if (root_name != nullptr) X509_NAME_free(root_name);
763     }
764   }
765   BIO_free(pem);
766   return result;
767 }
768 
ssl_ctx_load_verification_certs(SSL_CTX * context,const char * pem_roots,size_t pem_roots_size,STACK_OF (X509_NAME)** root_name)769 static tsi_result ssl_ctx_load_verification_certs(SSL_CTX* context,
770                                                   const char* pem_roots,
771                                                   size_t pem_roots_size,
772                                                   STACK_OF(X509_NAME) *
773                                                       *root_name) {
774   X509_STORE* cert_store = SSL_CTX_get_cert_store(context);
775   X509_STORE_set_flags(cert_store,
776                        X509_V_FLAG_PARTIAL_CHAIN | X509_V_FLAG_TRUSTED_FIRST);
777   return x509_store_load_certs(cert_store, pem_roots, pem_roots_size,
778                                root_name);
779 }
780 
781 // Populates the SSL context with a private key and a cert chain, and sets the
782 // cipher list and the ephemeral ECDH key.
populate_ssl_context(SSL_CTX * context,const tsi_ssl_pem_key_cert_pair * key_cert_pair,const char * cipher_list)783 static tsi_result populate_ssl_context(
784     SSL_CTX* context, const tsi_ssl_pem_key_cert_pair* key_cert_pair,
785     const char* cipher_list) {
786   tsi_result result = TSI_OK;
787   if (key_cert_pair != nullptr) {
788     if (key_cert_pair->cert_chain != nullptr) {
789       result = ssl_ctx_use_certificate_chain(context, key_cert_pair->cert_chain,
790                                              strlen(key_cert_pair->cert_chain));
791       if (result != TSI_OK) {
792         gpr_log(GPR_ERROR, "Invalid cert chain file.");
793         return result;
794       }
795     }
796     if (key_cert_pair->private_key != nullptr) {
797       result = ssl_ctx_use_private_key(context, key_cert_pair->private_key,
798                                        strlen(key_cert_pair->private_key));
799       if (result != TSI_OK || !SSL_CTX_check_private_key(context)) {
800         gpr_log(GPR_ERROR, "Invalid private key.");
801         return result != TSI_OK ? result : TSI_INVALID_ARGUMENT;
802       }
803     }
804   }
805   if ((cipher_list != nullptr) &&
806       !SSL_CTX_set_cipher_list(context, cipher_list)) {
807     gpr_log(GPR_ERROR, "Invalid cipher list: %s.", cipher_list);
808     return TSI_INVALID_ARGUMENT;
809   }
810   {
811 #if OPENSSL_VERSION_NUMBER < 0x30000000L
812     EC_KEY* ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
813     if (!SSL_CTX_set_tmp_ecdh(context, ecdh)) {
814       gpr_log(GPR_ERROR, "Could not set ephemeral ECDH key.");
815       EC_KEY_free(ecdh);
816       return TSI_INTERNAL_ERROR;
817     }
818     SSL_CTX_set_options(context, SSL_OP_SINGLE_ECDH_USE);
819     EC_KEY_free(ecdh);
820 #else
821     if (!SSL_CTX_set1_groups(context, kSslEcCurveNames, 1)) {
822       gpr_log(GPR_ERROR, "Could not set ephemeral ECDH key.");
823       return TSI_INTERNAL_ERROR;
824     }
825     SSL_CTX_set_options(context, SSL_OP_SINGLE_ECDH_USE);
826 #endif
827   }
828   return TSI_OK;
829 }
830 
831 // Extracts the CN and the SANs from an X509 cert as a peer object.
tsi_ssl_extract_x509_subject_names_from_pem_cert(const char * pem_cert,tsi_peer * peer)832 tsi_result tsi_ssl_extract_x509_subject_names_from_pem_cert(
833     const char* pem_cert, tsi_peer* peer) {
834   tsi_result result = TSI_OK;
835   X509* cert = nullptr;
836   BIO* pem;
837   pem = BIO_new_mem_buf(pem_cert, static_cast<int>(strlen(pem_cert)));
838   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
839 
840   cert = PEM_read_bio_X509(pem, nullptr, nullptr, const_cast<char*>(""));
841   if (cert == nullptr) {
842     gpr_log(GPR_ERROR, "Invalid certificate");
843     result = TSI_INVALID_ARGUMENT;
844   } else {
845     result = peer_from_x509(cert, 0, peer);
846   }
847   if (cert != nullptr) X509_free(cert);
848   BIO_free(pem);
849   return result;
850 }
851 
852 // Builds the alpn protocol name list according to rfc 7301.
build_alpn_protocol_name_list(const char ** alpn_protocols,uint16_t num_alpn_protocols,unsigned char ** protocol_name_list,size_t * protocol_name_list_length)853 static tsi_result build_alpn_protocol_name_list(
854     const char** alpn_protocols, uint16_t num_alpn_protocols,
855     unsigned char** protocol_name_list, size_t* protocol_name_list_length) {
856   uint16_t i;
857   unsigned char* current;
858   *protocol_name_list = nullptr;
859   *protocol_name_list_length = 0;
860   if (num_alpn_protocols == 0) return TSI_INVALID_ARGUMENT;
861   for (i = 0; i < num_alpn_protocols; i++) {
862     size_t length =
863         alpn_protocols[i] == nullptr ? 0 : strlen(alpn_protocols[i]);
864     if (length == 0 || length > 255) {
865       gpr_log(GPR_ERROR, "Invalid protocol name length: %d.",
866               static_cast<int>(length));
867       return TSI_INVALID_ARGUMENT;
868     }
869     *protocol_name_list_length += length + 1;
870   }
871   *protocol_name_list =
872       static_cast<unsigned char*>(gpr_malloc(*protocol_name_list_length));
873   if (*protocol_name_list == nullptr) return TSI_OUT_OF_RESOURCES;
874   current = *protocol_name_list;
875   for (i = 0; i < num_alpn_protocols; i++) {
876     size_t length = strlen(alpn_protocols[i]);
877     *(current++) = static_cast<uint8_t>(length);  // max checked above.
878     memcpy(current, alpn_protocols[i], length);
879     current += length;
880   }
881   // Safety check.
882   if ((current < *protocol_name_list) ||
883       (static_cast<uintptr_t>(current - *protocol_name_list) !=
884        *protocol_name_list_length)) {
885     return TSI_INTERNAL_ERROR;
886   }
887   return TSI_OK;
888 }
889 
890 // This callback is invoked when the CRL has been verified and will soft-fail
891 // errors in verification depending on certain error types.
verify_cb(int ok,X509_STORE_CTX * ctx)892 static int verify_cb(int ok, X509_STORE_CTX* ctx) {
893   int cert_error = X509_STORE_CTX_get_error(ctx);
894   if (cert_error == X509_V_ERR_UNABLE_TO_GET_CRL) {
895     gpr_log(GPR_INFO,
896             "Certificate verification failed to find relevant CRL file. "
897             "Ignoring error.");
898     return 1;
899   }
900   if (cert_error != 0) {
901     gpr_log(GPR_ERROR, "Certificate verify failed with code %d", cert_error);
902   }
903   return ok;
904 }
905 
906 // The verification callback is used for clients that don't really care about
907 // the server's certificate, but we need to pull it anyway, in case a higher
908 // layer wants to look at it. In this case the verification may fail, but
909 // we don't really care.
NullVerifyCallback(X509_STORE_CTX *,void *)910 static int NullVerifyCallback(X509_STORE_CTX* /*ctx*/, void* /*arg*/) {
911   return 1;
912 }
913 
RootCertExtractCallback(X509_STORE_CTX * ctx,void *)914 static int RootCertExtractCallback(X509_STORE_CTX* ctx, void* /*arg*/) {
915   int ret = 1;
916   // Verification was successful. Get the verified chain from the X509_STORE_CTX
917   // and put the root on the SSL object so that we have access to it when
918   // populating the tsi_peer. On error extracting the root, we return success
919   // anyway and proceed with the connection, to preserve the behavior of an
920   // older version of this code.
921 #if OPENSSL_VERSION_NUMBER >= 0x10100000
922   STACK_OF(X509)* chain = X509_STORE_CTX_get0_chain(ctx);
923 #else
924   STACK_OF(X509)* chain = X509_STORE_CTX_get_chain(ctx);
925 #endif
926   if (chain == nullptr) {
927     return ret;
928   }
929 
930   // The root cert is the last in the chain
931   size_t chain_length = sk_X509_num(chain);
932   if (chain_length == 0) {
933     return ret;
934   }
935   X509* root_cert = sk_X509_value(chain, chain_length - 1);
936   if (root_cert == nullptr) {
937     return ret;
938   }
939 
940   ERR_clear_error();
941   int ssl_index = SSL_get_ex_data_X509_STORE_CTX_idx();
942   if (ssl_index < 0) {
943     char err_str[256];
944     ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
945     gpr_log(GPR_ERROR,
946             "error getting the SSL index from the X509_STORE_CTX: %s", err_str);
947     return ret;
948   }
949   SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, ssl_index));
950   if (ssl == nullptr) {
951     return ret;
952   }
953 
954   // Free the old root and save the new one. There should not be an old root,
955   // but if renegotiation is not disabled (required by RFC 9113, Section
956   // 9.2.1), it is possible that this callback run multiple times for a single
957   // connection. gRPC does not always disable renegotiation. See
958   // https://github.com/grpc/grpc/issues/35368
959   X509_free(static_cast<X509*>(
960       SSL_get_ex_data(ssl, g_ssl_ex_verified_root_cert_index)));
961   int success =
962       SSL_set_ex_data(ssl, g_ssl_ex_verified_root_cert_index, root_cert);
963   if (success == 0) {
964     gpr_log(GPR_INFO, "Could not set verified root cert in SSL's ex_data");
965   } else {
966 #if OPENSSL_VERSION_NUMBER >= 0x10100000L
967     X509_up_ref(root_cert);
968 #else
969     CRYPTO_add(&root_cert->references, 1, CRYPTO_LOCK_X509);
970 #endif
971   }
972   return ret;
973 }
974 
GetCrlProvider(X509_STORE_CTX * ctx)975 static grpc_core::experimental::CrlProvider* GetCrlProvider(
976     X509_STORE_CTX* ctx) {
977   ERR_clear_error();
978   int ssl_index = SSL_get_ex_data_X509_STORE_CTX_idx();
979   if (ssl_index < 0) {
980     char err_str[256];
981     ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
982     gpr_log(GPR_INFO,
983             "error getting the SSL index from the X509_STORE_CTX while looking "
984             "up Crl: %s",
985             err_str);
986     return nullptr;
987   }
988   SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, ssl_index));
989   if (ssl == nullptr) {
990     gpr_log(GPR_INFO,
991             "error while fetching from CrlProvider. SSL object is null");
992     return nullptr;
993   }
994   SSL_CTX* ssl_ctx = SSL_get_SSL_CTX(ssl);
995   auto* provider = static_cast<grpc_core::experimental::CrlProvider*>(
996       SSL_CTX_get_ex_data(ssl_ctx, g_ssl_ctx_ex_crl_provider_index));
997   return provider;
998 }
999 
1000 // If a CRL is returned, the caller is the owner of the CRL and must make sure
1001 // it is freed.
GetCrlFromProvider(grpc_core::experimental::CrlProvider * provider,X509 * cert)1002 static absl::StatusOr<X509_CRL*> GetCrlFromProvider(
1003     grpc_core::experimental::CrlProvider* provider, X509* cert) {
1004   if (provider == nullptr) {
1005     return absl::InvalidArgumentError("CrlProvider is null.");
1006   }
1007   absl::StatusOr<std::string> issuer_name = grpc_core::IssuerFromCert(cert);
1008   if (!issuer_name.ok()) {
1009     gpr_log(GPR_INFO, "Could not get certificate issuer name");
1010     return absl::InvalidArgumentError(issuer_name.status().message());
1011   }
1012   absl::StatusOr<std::string> akid = grpc_core::AkidFromCertificate(cert);
1013   std::string akid_to_use;
1014   if (!akid.ok()) {
1015     gpr_log(GPR_INFO, "Could not get certificate authority key identifier.");
1016   } else {
1017     akid_to_use = *akid;
1018   }
1019 
1020   grpc_core::experimental::CertificateInfoImpl cert_impl(*issuer_name,
1021                                                          akid_to_use);
1022   std::shared_ptr<grpc_core::experimental::Crl> internal_crl =
1023       provider->GetCrl(cert_impl);
1024   // There wasn't a CRL found in the provider. Returning 0 will end up causing
1025   // OpenSSL to return X509_V_ERR_UNABLE_TO_GET_CRL. We then catch that error
1026   // and behave how we want for a missing CRL.
1027   // It is important to treat missing CRLs and empty CRLs differently.
1028   if (internal_crl == nullptr) {
1029     return absl::NotFoundError("Could not find Crl related to certificate.");
1030   }
1031   X509_CRL* crl =
1032       std::static_pointer_cast<grpc_core::experimental::CrlImpl>(internal_crl)
1033           ->crl();
1034 
1035   return X509_CRL_dup(crl);
1036 }
1037 
1038 // Perform the validation checks in RFC5280 6.3.3 to ensure the given CRL is
1039 // valid
1040 // returns true if the Crl is valid, false otherwise
ValidateCrl(X509 * cert,X509 * issuer,X509_CRL * crl)1041 static bool ValidateCrl(X509* cert, X509* issuer, X509_CRL* crl) {
1042   bool valid = true;
1043   // RFC5280 6.3.3
1044   // 6.3.3a we do not support distribution points
1045   // 6.3.3b verify issuer and scope
1046   valid = grpc_core::VerifyCrlCertIssuerNamesMatch(crl, cert);
1047   if (!valid) {
1048     gpr_log(GPR_DEBUG, "CRL and cert issuer names mismatched.");
1049     return valid;
1050   }
1051   valid = grpc_core::HasCrlSignBit(issuer);
1052   if (!valid) {
1053     gpr_log(GPR_DEBUG, "CRL issuer not allowed to sign CRLs.");
1054     return valid;
1055   }
1056   // 6.3.3c Not supporting deltas
1057   // 6.3.3d Not supporting reasons masks
1058   // 6.3.3e Not supporting reasons masks
1059   // 6.3.3f We only support direct CRLs so these paths are by definition the
1060   // same.
1061   // 6.3.3g Verify CRL Signature
1062   valid = grpc_core::VerifyCrlSignature(crl, issuer);
1063   if (!valid) {
1064     gpr_log(GPR_DEBUG, "Crl signature check failed.");
1065   }
1066   return valid;
1067 }
1068 
1069 // Check if a given certificate is revoked
1070 // Returns 1 if the certificate is not revoked, 0 if the certificate is revoked
CheckCertRevocation(grpc_core::experimental::CrlProvider * provider,X509 * cert,X509 * issuer)1071 static int CheckCertRevocation(grpc_core::experimental::CrlProvider* provider,
1072                                X509* cert, X509* issuer) {
1073   auto crl = GetCrlFromProvider(provider, cert);
1074   // Not finding a CRL is a specific behavior. Per RFC5280, not having a CRL to
1075   // check for a given certificate means that we cannot know for certain if the
1076   // status is Revoked or Unrevoked and instead is Undetermined. How a user
1077   // handles an Undetermined CRL is up to them. We use absl::IsNotFound as an
1078   // analogue for not finding the Crl from the provider, thus the certificate in
1079   // question is Undetermined.
1080   if (absl::IsNotFound(crl.status())) {
1081     // TODO(gtcooke94) knob for undetermined being revoked or unrevoked. By
1082     // default, unrevoked.
1083     return 1;
1084   } else if (!crl.ok()) {
1085     // This is an unexpected error, return false
1086     return 0;
1087   }
1088   // Validate the crl
1089   // RFC5280 6.3.3(a-i)
1090   if (!ValidateCrl(cert, issuer, *crl)) {
1091     X509_CRL_free(*crl);
1092     return 0;
1093   }
1094 
1095   // RFC5280 6.3.3j Actually check revocation
1096   // Look for serial number of certificate in CRL  X509_REVOKED* rev =
1097   // nullptr;
1098   X509_REVOKED* rev;
1099   if (X509_CRL_get0_by_cert(*crl, &rev, cert)) {
1100     // cert is revoked
1101     X509_CRL_free(*crl);
1102     return 0;
1103   }
1104   // The certificate is not revoked
1105   // RFC5280k - Not supported
1106   // RFC5280l - Not supported
1107   X509_CRL_free(*crl);
1108   return 1;
1109 }
1110 
1111 // Checks each certificate in the chain for revocation
1112 // returns 0 if any cert in the chain is revoked, 1 otherwise.
CheckChainRevocation(X509_STORE_CTX * ctx,grpc_core::experimental::CrlProvider * provider)1113 static int CheckChainRevocation(
1114     X509_STORE_CTX* ctx, grpc_core::experimental::CrlProvider* provider) {
1115 #if OPENSSL_VERSION_NUMBER >= 0x10100000
1116   STACK_OF(X509)* chain = X509_STORE_CTX_get0_chain(ctx);
1117 #else
1118   STACK_OF(X509)* chain = X509_STORE_CTX_get_chain(ctx);
1119 #endif
1120   if (chain == nullptr) {
1121     return 0;
1122   }
1123   // BoringSSL returns a size_t (unsigned), while OpenSSL returns an int
1124   // (signed). In OpenSSL, a -1 can indicate a problem. By forcing it into a
1125   // size_t, a -1 return will result in the chain_length being a very large
1126   // number, so it will still fail this check because that very large number
1127   // will be >= kMaxChainLength
1128   size_t chain_length = sk_X509_num(chain);
1129   if (chain_length > kMaxChainLength || chain_length == 0) {
1130     return 0;
1131   }
1132   // Loop to < chain_length - 1 because the last cert is the trust anchor/root
1133   // which cannot be revoked
1134   for (size_t i = 0; i < chain_length - 1; i++) {
1135     X509* cert = sk_X509_value(chain, i);
1136     X509* issuer = sk_X509_value(chain, i + 1);
1137     int ret = CheckCertRevocation(provider, cert, issuer);
1138     if (ret != 1) {
1139       return ret;
1140     }
1141   }
1142   return 1;
1143 }
1144 
1145 // The custom verification function to set in OpenSSL using
1146 // X509_set_cert_verify_callback. This calls the standard OpenSSL procedure
1147 // (X509_verify_cert), then also extracts the root certificate in the built
1148 // chain and does revocation checks when a user has configured CrlProviders.
1149 // returns 1 on success, indicating a trusted chain to a root of trust was
1150 // found, 0 if a trusted chain could not be built.
CustomVerificationFunction(X509_STORE_CTX * ctx,void * arg)1151 static int CustomVerificationFunction(X509_STORE_CTX* ctx, void* arg) {
1152   int ret = X509_verify_cert(ctx);
1153   if (ret <= 0) {
1154     gpr_log(GPR_DEBUG, "Failed to verify cert chain.");
1155     // Verification failed. We shouldn't expect to have a verified chain, so
1156     // there is no need to attempt to extract the root cert from it, check for
1157     // revocation, or check anything else.
1158     return ret;
1159   }
1160   grpc_core::experimental::CrlProvider* provider = GetCrlProvider(ctx);
1161   if (provider != nullptr) {
1162     ret = CheckChainRevocation(ctx, provider);
1163     if (ret <= 0) {
1164       gpr_log(GPR_DEBUG, "The chain failed revocation checks.");
1165       return ret;
1166     }
1167   }
1168   return RootCertExtractCallback(ctx, arg);
1169 }
1170 
1171 // Sets the min and max TLS version of |ssl_context| to |min_tls_version| and
1172 // |max_tls_version|, respectively. Calling this method is a no-op when using
1173 // OpenSSL versions < 1.1.
tsi_set_min_and_max_tls_versions(SSL_CTX * ssl_context,tsi_tls_version min_tls_version,tsi_tls_version max_tls_version)1174 static tsi_result tsi_set_min_and_max_tls_versions(
1175     SSL_CTX* ssl_context, tsi_tls_version min_tls_version,
1176     tsi_tls_version max_tls_version) {
1177   if (ssl_context == nullptr) {
1178     gpr_log(GPR_INFO,
1179             "Invalid nullptr argument to |tsi_set_min_and_max_tls_versions|.");
1180     return TSI_INVALID_ARGUMENT;
1181   }
1182 #if OPENSSL_VERSION_NUMBER >= 0x10100000
1183   // Set the min TLS version of the SSL context if using OpenSSL version
1184   // >= 1.1.0. This OpenSSL version is required because the
1185   // |SSL_CTX_set_min_proto_version| and |SSL_CTX_set_max_proto_version| APIs
1186   // only exist in this version range.
1187   switch (min_tls_version) {
1188     case tsi_tls_version::TSI_TLS1_2:
1189       SSL_CTX_set_min_proto_version(ssl_context, TLS1_2_VERSION);
1190       break;
1191 #if defined(TLS1_3_VERSION)
1192     // If the library does not support TLS 1.3 and the caller requests a
1193     // minimum of TLS 1.3, then return an error because the caller's request
1194     // cannot be satisfied.
1195     case tsi_tls_version::TSI_TLS1_3:
1196       SSL_CTX_set_min_proto_version(ssl_context, TLS1_3_VERSION);
1197       break;
1198 #endif
1199     default:
1200       gpr_log(GPR_INFO, "TLS version is not supported.");
1201       return TSI_FAILED_PRECONDITION;
1202   }
1203 
1204   // Set the max TLS version of the SSL context.
1205   switch (max_tls_version) {
1206     case tsi_tls_version::TSI_TLS1_2:
1207       SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION);
1208       break;
1209     case tsi_tls_version::TSI_TLS1_3:
1210 #if defined(TLS1_3_VERSION)
1211       SSL_CTX_set_max_proto_version(ssl_context, TLS1_3_VERSION);
1212 #else
1213       // If the library does not support TLS 1.3, then set the max TLS version
1214       // to TLS 1.2 instead.
1215       SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION);
1216 #endif
1217       break;
1218     default:
1219       gpr_log(GPR_INFO, "TLS version is not supported.");
1220       return TSI_FAILED_PRECONDITION;
1221   }
1222 #endif
1223   return TSI_OK;
1224 }
1225 
1226 // --- tsi_ssl_root_certs_store methods implementation. ---
1227 
tsi_ssl_root_certs_store_create(const char * pem_roots)1228 tsi_ssl_root_certs_store* tsi_ssl_root_certs_store_create(
1229     const char* pem_roots) {
1230   if (pem_roots == nullptr) {
1231     gpr_log(GPR_ERROR, "The root certificates are empty.");
1232     return nullptr;
1233   }
1234   tsi_ssl_root_certs_store* root_store = static_cast<tsi_ssl_root_certs_store*>(
1235       gpr_zalloc(sizeof(tsi_ssl_root_certs_store)));
1236   if (root_store == nullptr) {
1237     gpr_log(GPR_ERROR, "Could not allocate buffer for ssl_root_certs_store.");
1238     return nullptr;
1239   }
1240   root_store->store = X509_STORE_new();
1241   if (root_store->store == nullptr) {
1242     gpr_log(GPR_ERROR, "Could not allocate buffer for X509_STORE.");
1243     gpr_free(root_store);
1244     return nullptr;
1245   }
1246   tsi_result result = x509_store_load_certs(root_store->store, pem_roots,
1247                                             strlen(pem_roots), nullptr);
1248   if (result != TSI_OK) {
1249     gpr_log(GPR_ERROR, "Could not load root certificates.");
1250     X509_STORE_free(root_store->store);
1251     gpr_free(root_store);
1252     return nullptr;
1253   }
1254 #if OPENSSL_VERSION_NUMBER >= 0x10100000
1255   X509_VERIFY_PARAM* param = X509_STORE_get0_param(root_store->store);
1256 #else
1257   X509_VERIFY_PARAM* param = root_store->store->param;
1258 #endif
1259   X509_VERIFY_PARAM_set_depth(param, kMaxChainLength);
1260   return root_store;
1261 }
1262 
tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store * self)1263 void tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store* self) {
1264   if (self == nullptr) return;
1265   X509_STORE_free(self->store);
1266   gpr_free(self);
1267 }
1268 
1269 // --- tsi_ssl_session_cache methods implementation. ---
1270 
tsi_ssl_session_cache_create_lru(size_t capacity)1271 tsi_ssl_session_cache* tsi_ssl_session_cache_create_lru(size_t capacity) {
1272   // Pointer will be dereferenced by unref call.
1273   return tsi::SslSessionLRUCache::Create(capacity).release()->c_ptr();
1274 }
1275 
tsi_ssl_session_cache_ref(tsi_ssl_session_cache * cache)1276 void tsi_ssl_session_cache_ref(tsi_ssl_session_cache* cache) {
1277   // Pointer will be dereferenced by unref call.
1278   tsi::SslSessionLRUCache::FromC(cache)->Ref().release();
1279 }
1280 
tsi_ssl_session_cache_unref(tsi_ssl_session_cache * cache)1281 void tsi_ssl_session_cache_unref(tsi_ssl_session_cache* cache) {
1282   tsi::SslSessionLRUCache::FromC(cache)->Unref();
1283 }
1284 
1285 // --- tsi_frame_protector methods implementation. ---
1286 
ssl_protector_protect(tsi_frame_protector * self,const unsigned char * unprotected_bytes,size_t * unprotected_bytes_size,unsigned char * protected_output_frames,size_t * protected_output_frames_size)1287 static tsi_result ssl_protector_protect(tsi_frame_protector* self,
1288                                         const unsigned char* unprotected_bytes,
1289                                         size_t* unprotected_bytes_size,
1290                                         unsigned char* protected_output_frames,
1291                                         size_t* protected_output_frames_size) {
1292   tsi_ssl_frame_protector* impl =
1293       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1294 
1295   return grpc_core::SslProtectorProtect(
1296       unprotected_bytes, impl->buffer_size, impl->buffer_offset, impl->buffer,
1297       impl->ssl, impl->network_io, unprotected_bytes_size,
1298       protected_output_frames, protected_output_frames_size);
1299 }
1300 
ssl_protector_protect_flush(tsi_frame_protector * self,unsigned char * protected_output_frames,size_t * protected_output_frames_size,size_t * still_pending_size)1301 static tsi_result ssl_protector_protect_flush(
1302     tsi_frame_protector* self, unsigned char* protected_output_frames,
1303     size_t* protected_output_frames_size, size_t* still_pending_size) {
1304   tsi_ssl_frame_protector* impl =
1305       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1306   return grpc_core::SslProtectorProtectFlush(
1307       impl->buffer_offset, impl->buffer, impl->ssl, impl->network_io,
1308       protected_output_frames, protected_output_frames_size,
1309       still_pending_size);
1310 }
1311 
ssl_protector_unprotect(tsi_frame_protector * self,const unsigned char * protected_frames_bytes,size_t * protected_frames_bytes_size,unsigned char * unprotected_bytes,size_t * unprotected_bytes_size)1312 static tsi_result ssl_protector_unprotect(
1313     tsi_frame_protector* self, const unsigned char* protected_frames_bytes,
1314     size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes,
1315     size_t* unprotected_bytes_size) {
1316   tsi_ssl_frame_protector* impl =
1317       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1318   return grpc_core::SslProtectorUnprotect(
1319       protected_frames_bytes, impl->ssl, impl->network_io,
1320       protected_frames_bytes_size, unprotected_bytes, unprotected_bytes_size);
1321 }
1322 
ssl_protector_destroy(tsi_frame_protector * self)1323 static void ssl_protector_destroy(tsi_frame_protector* self) {
1324   tsi_ssl_frame_protector* impl =
1325       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1326   if (impl->buffer != nullptr) gpr_free(impl->buffer);
1327   if (impl->ssl != nullptr) SSL_free(impl->ssl);
1328   if (impl->network_io != nullptr) BIO_free(impl->network_io);
1329   gpr_free(self);
1330 }
1331 
1332 static const tsi_frame_protector_vtable frame_protector_vtable = {
1333     ssl_protector_protect,
1334     ssl_protector_protect_flush,
1335     ssl_protector_unprotect,
1336     ssl_protector_destroy,
1337 };
1338 
1339 // --- tsi_server_handshaker_factory methods implementation. ---
1340 
tsi_ssl_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)1341 static void tsi_ssl_handshaker_factory_destroy(
1342     tsi_ssl_handshaker_factory* factory) {
1343   if (factory == nullptr) return;
1344 
1345   if (factory->vtable != nullptr && factory->vtable->destroy != nullptr) {
1346     factory->vtable->destroy(factory);
1347   }
1348   // Note, we don't free(self) here because this object is always directly
1349   // embedded in another object. If tsi_ssl_handshaker_factory_init allocates
1350   // any memory, it should be free'd here.
1351 }
1352 
tsi_ssl_handshaker_factory_ref(tsi_ssl_handshaker_factory * factory)1353 static tsi_ssl_handshaker_factory* tsi_ssl_handshaker_factory_ref(
1354     tsi_ssl_handshaker_factory* factory) {
1355   if (factory == nullptr) return nullptr;
1356   gpr_refn(&factory->refcount, 1);
1357   return factory;
1358 }
1359 
tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory * factory)1360 static void tsi_ssl_handshaker_factory_unref(
1361     tsi_ssl_handshaker_factory* factory) {
1362   if (factory == nullptr) return;
1363 
1364   if (gpr_unref(&factory->refcount)) {
1365     tsi_ssl_handshaker_factory_destroy(factory);
1366   }
1367 }
1368 
1369 static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {nullptr};
1370 
1371 // Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for
1372 // allocating memory for the factory.
tsi_ssl_handshaker_factory_init(tsi_ssl_handshaker_factory * factory)1373 static void tsi_ssl_handshaker_factory_init(
1374     tsi_ssl_handshaker_factory* factory) {
1375   GPR_ASSERT(factory != nullptr);
1376 
1377   factory->vtable = &handshaker_factory_vtable;
1378   gpr_ref_init(&factory->refcount, 1);
1379 }
1380 
1381 // Gets the X509 cert chain in PEM format as a tsi_peer_property.
tsi_ssl_get_cert_chain_contents(STACK_OF (X509)* peer_chain,tsi_peer_property * property)1382 tsi_result tsi_ssl_get_cert_chain_contents(STACK_OF(X509) * peer_chain,
1383                                            tsi_peer_property* property) {
1384   BIO* bio = BIO_new(BIO_s_mem());
1385   const auto peer_chain_len = sk_X509_num(peer_chain);
1386   for (auto i = decltype(peer_chain_len){0}; i < peer_chain_len; i++) {
1387     if (!PEM_write_bio_X509(bio, sk_X509_value(peer_chain, i))) {
1388       BIO_free(bio);
1389       return TSI_INTERNAL_ERROR;
1390     }
1391   }
1392   char* contents;
1393   long len = BIO_get_mem_data(bio, &contents);
1394   if (len <= 0) {
1395     BIO_free(bio);
1396     return TSI_INTERNAL_ERROR;
1397   }
1398   tsi_result result = tsi_construct_string_peer_property(
1399       TSI_X509_PEM_CERT_CHAIN_PROPERTY, contents, static_cast<size_t>(len),
1400       property);
1401   BIO_free(bio);
1402   return result;
1403 }
1404 
1405 // --- tsi_handshaker_result methods implementation. ---
ssl_handshaker_result_extract_peer(const tsi_handshaker_result * self,tsi_peer * peer)1406 static tsi_result ssl_handshaker_result_extract_peer(
1407     const tsi_handshaker_result* self, tsi_peer* peer) {
1408   tsi_result result = TSI_OK;
1409   const unsigned char* alpn_selected = nullptr;
1410   unsigned int alpn_selected_len;
1411   const tsi_ssl_handshaker_result* impl =
1412       reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
1413   X509* peer_cert = SSL_get_peer_certificate(impl->ssl);
1414   if (peer_cert != nullptr) {
1415     result = peer_from_x509(peer_cert, 1, peer);
1416     X509_free(peer_cert);
1417     if (result != TSI_OK) return result;
1418   }
1419 #if TSI_OPENSSL_ALPN_SUPPORT
1420   SSL_get0_alpn_selected(impl->ssl, &alpn_selected, &alpn_selected_len);
1421 #endif  // TSI_OPENSSL_ALPN_SUPPORT
1422   if (alpn_selected == nullptr) {
1423     // Try npn.
1424     SSL_get0_next_proto_negotiated(impl->ssl, &alpn_selected,
1425                                    &alpn_selected_len);
1426   }
1427   // When called on the client side, the stack also contains the
1428   // peer's certificate; When called on the server side,
1429   // the peer's certificate is not present in the stack
1430   STACK_OF(X509)* peer_chain = SSL_get_peer_cert_chain(impl->ssl);
1431 
1432   X509* verified_root_cert = static_cast<X509*>(
1433       SSL_get_ex_data(impl->ssl, g_ssl_ex_verified_root_cert_index));
1434   // 1 is for session reused property.
1435   size_t new_property_count = peer->property_count + 3;
1436   if (alpn_selected != nullptr) new_property_count++;
1437   if (peer_chain != nullptr) new_property_count++;
1438   if (verified_root_cert != nullptr) new_property_count++;
1439   tsi_peer_property* new_properties = static_cast<tsi_peer_property*>(
1440       gpr_zalloc(sizeof(*new_properties) * new_property_count));
1441   for (size_t i = 0; i < peer->property_count; i++) {
1442     new_properties[i] = peer->properties[i];
1443   }
1444   if (peer->properties != nullptr) gpr_free(peer->properties);
1445   peer->properties = new_properties;
1446   // Add peer chain if available
1447   if (peer_chain != nullptr) {
1448     result = tsi_ssl_get_cert_chain_contents(
1449         peer_chain, &peer->properties[peer->property_count]);
1450     if (result == TSI_OK) peer->property_count++;
1451   }
1452   if (alpn_selected != nullptr) {
1453     result = tsi_construct_string_peer_property(
1454         TSI_SSL_ALPN_SELECTED_PROTOCOL,
1455         reinterpret_cast<const char*>(alpn_selected), alpn_selected_len,
1456         &peer->properties[peer->property_count]);
1457     if (result != TSI_OK) return result;
1458     peer->property_count++;
1459   }
1460   // Add security_level peer property.
1461   result = tsi_construct_string_peer_property_from_cstring(
1462       TSI_SECURITY_LEVEL_PEER_PROPERTY,
1463       tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY),
1464       &peer->properties[peer->property_count]);
1465   if (result != TSI_OK) return result;
1466   peer->property_count++;
1467 
1468   const char* session_reused = SSL_session_reused(impl->ssl) ? "true" : "false";
1469   result = tsi_construct_string_peer_property_from_cstring(
1470       TSI_SSL_SESSION_REUSED_PEER_PROPERTY, session_reused,
1471       &peer->properties[peer->property_count]);
1472   if (result != TSI_OK) return result;
1473   peer->property_count++;
1474 
1475   if (verified_root_cert != nullptr) {
1476     result = peer_property_from_x509_subject(
1477         verified_root_cert, &peer->properties[peer->property_count], true);
1478     if (result != TSI_OK) {
1479       gpr_log(GPR_DEBUG,
1480               "Problem extracting subject from verified_root_cert. result: %d",
1481               static_cast<int>(result));
1482     }
1483     peer->property_count++;
1484   }
1485 
1486   return result;
1487 }
1488 
ssl_handshaker_result_get_frame_protector_type(const tsi_handshaker_result *,tsi_frame_protector_type * frame_protector_type)1489 static tsi_result ssl_handshaker_result_get_frame_protector_type(
1490     const tsi_handshaker_result* /*self*/,
1491     tsi_frame_protector_type* frame_protector_type) {
1492   *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL;
1493   return TSI_OK;
1494 }
1495 
ssl_handshaker_result_create_frame_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)1496 static tsi_result ssl_handshaker_result_create_frame_protector(
1497     const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
1498     tsi_frame_protector** protector) {
1499   size_t actual_max_output_protected_frame_size =
1500       TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
1501   tsi_ssl_handshaker_result* impl =
1502       reinterpret_cast<tsi_ssl_handshaker_result*>(
1503           const_cast<tsi_handshaker_result*>(self));
1504   tsi_ssl_frame_protector* protector_impl =
1505       static_cast<tsi_ssl_frame_protector*>(
1506           gpr_zalloc(sizeof(*protector_impl)));
1507 
1508   if (max_output_protected_frame_size != nullptr) {
1509     if (*max_output_protected_frame_size >
1510         TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND) {
1511       *max_output_protected_frame_size =
1512           TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
1513     } else if (*max_output_protected_frame_size <
1514                TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND) {
1515       *max_output_protected_frame_size =
1516           TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND;
1517     }
1518     actual_max_output_protected_frame_size = *max_output_protected_frame_size;
1519   }
1520   protector_impl->buffer_size =
1521       actual_max_output_protected_frame_size - TSI_SSL_MAX_PROTECTION_OVERHEAD;
1522   protector_impl->buffer =
1523       static_cast<unsigned char*>(gpr_malloc(protector_impl->buffer_size));
1524   if (protector_impl->buffer == nullptr) {
1525     gpr_log(GPR_ERROR,
1526             "Could not allocated buffer for tsi_ssl_frame_protector.");
1527     gpr_free(protector_impl);
1528     return TSI_INTERNAL_ERROR;
1529   }
1530 
1531   // Transfer ownership of ssl and network_io to the frame protector.
1532   protector_impl->ssl = impl->ssl;
1533   impl->ssl = nullptr;
1534   protector_impl->network_io = impl->network_io;
1535   impl->network_io = nullptr;
1536   protector_impl->base.vtable = &frame_protector_vtable;
1537   *protector = &protector_impl->base;
1538   return TSI_OK;
1539 }
1540 
ssl_handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)1541 static tsi_result ssl_handshaker_result_get_unused_bytes(
1542     const tsi_handshaker_result* self, const unsigned char** bytes,
1543     size_t* bytes_size) {
1544   const tsi_ssl_handshaker_result* impl =
1545       reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
1546   *bytes_size = impl->unused_bytes_size;
1547   *bytes = impl->unused_bytes;
1548   return TSI_OK;
1549 }
1550 
ssl_handshaker_result_destroy(tsi_handshaker_result * self)1551 static void ssl_handshaker_result_destroy(tsi_handshaker_result* self) {
1552   tsi_ssl_handshaker_result* impl =
1553       reinterpret_cast<tsi_ssl_handshaker_result*>(self);
1554   SSL_free(impl->ssl);
1555   BIO_free(impl->network_io);
1556   gpr_free(impl->unused_bytes);
1557   gpr_free(impl);
1558 }
1559 
// Dispatch table for tsi_ssl_handshaker_result. The initializer is positional,
// so entries must stay in the field order of tsi_handshaker_result_vtable.
static const tsi_handshaker_result_vtable handshaker_result_vtable = {
    ssl_handshaker_result_extract_peer,
    ssl_handshaker_result_get_frame_protector_type,
    nullptr,  // create_zero_copy_grpc_protector -- not provided by SSL
    ssl_handshaker_result_create_frame_protector,
    ssl_handshaker_result_get_unused_bytes,
    ssl_handshaker_result_destroy,
};
1568 
ssl_handshaker_result_create(tsi_ssl_handshaker * handshaker,unsigned char * unused_bytes,size_t unused_bytes_size,tsi_handshaker_result ** handshaker_result,std::string * error)1569 static tsi_result ssl_handshaker_result_create(
1570     tsi_ssl_handshaker* handshaker, unsigned char* unused_bytes,
1571     size_t unused_bytes_size, tsi_handshaker_result** handshaker_result,
1572     std::string* error) {
1573   if (handshaker == nullptr || handshaker_result == nullptr ||
1574       (unused_bytes_size > 0 && unused_bytes == nullptr)) {
1575     if (error != nullptr) *error = "invalid argument";
1576     return TSI_INVALID_ARGUMENT;
1577   }
1578   tsi_ssl_handshaker_result* result =
1579       grpc_core::Zalloc<tsi_ssl_handshaker_result>();
1580   result->base.vtable = &handshaker_result_vtable;
1581   // Transfer ownership of ssl and network_io to the handshaker result.
1582   result->ssl = handshaker->ssl;
1583   handshaker->ssl = nullptr;
1584   result->network_io = handshaker->network_io;
1585   handshaker->network_io = nullptr;
1586   // Transfer ownership of |unused_bytes| to the handshaker result.
1587   result->unused_bytes = unused_bytes;
1588   result->unused_bytes_size = unused_bytes_size;
1589   *handshaker_result = &result->base;
1590   return TSI_OK;
1591 }
1592 
1593 // --- tsi_handshaker methods implementation. ---
1594 
ssl_handshaker_get_bytes_to_send_to_peer(tsi_ssl_handshaker * impl,unsigned char * bytes,size_t * bytes_size,std::string * error)1595 static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(
1596     tsi_ssl_handshaker* impl, unsigned char* bytes, size_t* bytes_size,
1597     std::string* error) {
1598   int bytes_read_from_ssl = 0;
1599   if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
1600     if (error != nullptr) *error = "invalid argument";
1601     return TSI_INVALID_ARGUMENT;
1602   }
1603   GPR_ASSERT(*bytes_size <= INT_MAX);
1604   bytes_read_from_ssl =
1605       BIO_read(impl->network_io, bytes, static_cast<int>(*bytes_size));
1606   if (bytes_read_from_ssl < 0) {
1607     *bytes_size = 0;
1608     if (!BIO_should_retry(impl->network_io)) {
1609       if (error != nullptr) *error = "error reading from BIO";
1610       impl->result = TSI_INTERNAL_ERROR;
1611       return impl->result;
1612     } else {
1613       return TSI_OK;
1614     }
1615   }
1616   *bytes_size = static_cast<size_t>(bytes_read_from_ssl);
1617   return BIO_pending(impl->network_io) == 0 ? TSI_OK : TSI_INCOMPLETE_DATA;
1618 }
1619 
ssl_handshaker_get_result(tsi_ssl_handshaker * impl)1620 static tsi_result ssl_handshaker_get_result(tsi_ssl_handshaker* impl) {
1621   if ((impl->result == TSI_HANDSHAKE_IN_PROGRESS) &&
1622       SSL_is_init_finished(impl->ssl)) {
1623     impl->result = TSI_OK;
1624   }
1625   return impl->result;
1626 }
1627 
ssl_handshaker_do_handshake(tsi_ssl_handshaker * impl,std::string * error)1628 static tsi_result ssl_handshaker_do_handshake(tsi_ssl_handshaker* impl,
1629                                               std::string* error) {
1630   if (ssl_handshaker_get_result(impl) != TSI_HANDSHAKE_IN_PROGRESS) {
1631     impl->result = TSI_OK;
1632     return impl->result;
1633   } else {
1634     ERR_clear_error();
1635     // Get ready to get some bytes from SSL.
1636     int ssl_result = SSL_do_handshake(impl->ssl);
1637     ssl_result = SSL_get_error(impl->ssl, ssl_result);
1638     switch (ssl_result) {
1639       case SSL_ERROR_WANT_READ:
1640         if (BIO_pending(impl->network_io) == 0) {
1641           // We need more data.
1642           return TSI_INCOMPLETE_DATA;
1643         } else {
1644           return TSI_OK;
1645         }
1646       case SSL_ERROR_NONE:
1647         return TSI_OK;
1648       case SSL_ERROR_WANT_WRITE:
1649         return TSI_DRAIN_BUFFER;
1650       default: {
1651         char err_str[256];
1652         ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
1653         gpr_log(GPR_ERROR, "Handshake failed with fatal error %s: %s.",
1654                 grpc_core::SslErrorString(ssl_result), err_str);
1655         if (error != nullptr) {
1656           *error = absl::StrCat(grpc_core::SslErrorString(ssl_result), ": ",
1657                                 err_str);
1658         }
1659         impl->result = TSI_PROTOCOL_FAILURE;
1660         return impl->result;
1661       }
1662     }
1663   }
1664 }
1665 
ssl_handshaker_process_bytes_from_peer(tsi_ssl_handshaker * impl,const unsigned char * bytes,size_t * bytes_size,std::string * error)1666 static tsi_result ssl_handshaker_process_bytes_from_peer(
1667     tsi_ssl_handshaker* impl, const unsigned char* bytes, size_t* bytes_size,
1668     std::string* error) {
1669   int bytes_written_into_ssl_size = 0;
1670   if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
1671     if (error != nullptr) *error = "invalid argument";
1672     return TSI_INVALID_ARGUMENT;
1673   }
1674   GPR_ASSERT(*bytes_size <= INT_MAX);
1675   bytes_written_into_ssl_size =
1676       BIO_write(impl->network_io, bytes, static_cast<int>(*bytes_size));
1677   if (bytes_written_into_ssl_size < 0) {
1678     gpr_log(GPR_ERROR, "Could not write to memory BIO.");
1679     if (error != nullptr) *error = "could not write to memory BIO";
1680     impl->result = TSI_INTERNAL_ERROR;
1681     return impl->result;
1682   }
1683   *bytes_size = static_cast<size_t>(bytes_written_into_ssl_size);
1684   return ssl_handshaker_do_handshake(impl, error);
1685 }
1686 
ssl_handshaker_destroy(tsi_handshaker * self)1687 static void ssl_handshaker_destroy(tsi_handshaker* self) {
1688   tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1689   SSL_free(impl->ssl);
1690   BIO_free(impl->network_io);
1691   gpr_free(impl->outgoing_bytes_buffer);
1692   tsi_ssl_handshaker_factory_unref(impl->factory_ref);
1693   gpr_free(impl);
1694 }
1695 
1696 // Removes the bytes remaining in |impl->SSL|'s read BIO and writes them to
1697 // |bytes_remaining|.
ssl_bytes_remaining(tsi_ssl_handshaker * impl,unsigned char ** bytes_remaining,size_t * bytes_remaining_size,std::string * error)1698 static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker* impl,
1699                                       unsigned char** bytes_remaining,
1700                                       size_t* bytes_remaining_size,
1701                                       std::string* error) {
1702   if (impl == nullptr || bytes_remaining == nullptr ||
1703       bytes_remaining_size == nullptr) {
1704     if (error != nullptr) *error = "invalid argument";
1705     return TSI_INVALID_ARGUMENT;
1706   }
1707   // Atempt to read all of the bytes in SSL's read BIO. These bytes should
1708   // contain application data records that were appended to a handshake record
1709   // containing the ClientFinished or ServerFinished message.
1710   size_t bytes_in_ssl = BIO_pending(SSL_get_rbio(impl->ssl));
1711   if (bytes_in_ssl == 0) return TSI_OK;
1712   *bytes_remaining = static_cast<uint8_t*>(gpr_malloc(bytes_in_ssl));
1713   int bytes_read = BIO_read(SSL_get_rbio(impl->ssl), *bytes_remaining,
1714                             static_cast<int>(bytes_in_ssl));
1715   // If an unexpected number of bytes were read, return an error status and
1716   // free all of the bytes that were read.
1717   if (bytes_read < 0 || static_cast<size_t>(bytes_read) != bytes_in_ssl) {
1718     gpr_log(GPR_ERROR,
1719             "Failed to read the expected number of bytes from SSL object.");
1720     gpr_free(*bytes_remaining);
1721     *bytes_remaining = nullptr;
1722     if (error != nullptr) {
1723       *error = "Failed to read the expected number of bytes from SSL object.";
1724     }
1725     return TSI_INTERNAL_ERROR;
1726   }
1727   *bytes_remaining_size = static_cast<size_t>(bytes_read);
1728   return TSI_OK;
1729 }
1730 
1731 // Write handshake data received from SSL to an unbound output buffer.
1732 // By doing that, we drain SSL bio buffer used to hold handshake data.
1733 // This API needs to be repeatedly called until all handshake data are
1734 // received from SSL.
ssl_handshaker_write_output_buffer(tsi_handshaker * self,size_t * bytes_written,std::string * error)1735 static tsi_result ssl_handshaker_write_output_buffer(tsi_handshaker* self,
1736                                                      size_t* bytes_written,
1737                                                      std::string* error) {
1738   tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1739   tsi_result status = TSI_OK;
1740   size_t offset = *bytes_written;
1741   do {
1742     size_t to_send_size = impl->outgoing_bytes_buffer_size - offset;
1743     status = ssl_handshaker_get_bytes_to_send_to_peer(
1744         impl, impl->outgoing_bytes_buffer + offset, &to_send_size, error);
1745     offset += to_send_size;
1746     if (status == TSI_INCOMPLETE_DATA) {
1747       impl->outgoing_bytes_buffer_size *= 2;
1748       impl->outgoing_bytes_buffer = static_cast<unsigned char*>(gpr_realloc(
1749           impl->outgoing_bytes_buffer, impl->outgoing_bytes_buffer_size));
1750     }
1751   } while (status == TSI_INCOMPLETE_DATA);
1752   *bytes_written = offset;
1753   return status;
1754 }
1755 
// Implements the `next` entry of tsi_handshaker_vtable for the SSL handshaker.
//
// Any |received_bytes| are fed to SSL. Handshake bytes SSL produces are
// gathered in |impl->outgoing_bytes_buffer| and exposed through
// |*bytes_to_send| / |*bytes_to_send_size|; the buffer stays owned by the
// handshaker. If ssl_handshaker_get_result reports that the handshake is no
// longer in progress, a tsi_handshaker_result holding any leftover peer bytes
// is stored in |*handshaker_result|; otherwise it is set to null.
// The callback and user data are ignored: the outcome is always returned
// synchronously from this call.
// Returns TSI_INVALID_ARGUMENT on bad arguments, or the first non-TSI_OK status
// from the helpers (TSI_INCOMPLETE_DATA means more peer bytes are needed).
static tsi_result ssl_handshaker_next(tsi_handshaker* self,
                                      const unsigned char* received_bytes,
                                      size_t received_bytes_size,
                                      const unsigned char** bytes_to_send,
                                      size_t* bytes_to_send_size,
                                      tsi_handshaker_result** handshaker_result,
                                      tsi_handshaker_on_next_done_cb /*cb*/,
                                      void* /*user_data*/, std::string* error) {
  // Input sanity check.
  if ((received_bytes_size > 0 && received_bytes == nullptr) ||
      bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
      handshaker_result == nullptr) {
    if (error != nullptr) *error = "invalid argument";
    return TSI_INVALID_ARGUMENT;
  }
  // If there are received bytes, process them first.
  tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
  tsi_result status = TSI_OK;
  // Running count of bytes placed in impl->outgoing_bytes_buffer during this
  // call; shared by every ssl_handshaker_write_output_buffer call below.
  size_t bytes_written = 0;
  if (received_bytes_size > 0) {
    unsigned char* remaining_bytes_to_write_to_openssl =
        const_cast<unsigned char*>(received_bytes);
    size_t remaining_bytes_to_write_to_openssl_size = received_bytes_size;
    size_t number_bio_write_attempts = 0;
    // NOTE(review): if TSI_SSL_MAX_BIO_WRITE_ATTEMPTS is reached while bytes
    // remain and status is TSI_OK, the leftover received bytes are never fed
    // to SSL and the function carries on. Confirm the limit makes this
    // unreachable in practice.
    while (remaining_bytes_to_write_to_openssl_size > 0 &&
           (status == TSI_OK || status == TSI_INCOMPLETE_DATA) &&
           number_bio_write_attempts < TSI_SSL_MAX_BIO_WRITE_ATTEMPTS) {
      ++number_bio_write_attempts;
      // Try to write all of the remaining bytes to the BIO.
      size_t bytes_written_to_openssl =
          remaining_bytes_to_write_to_openssl_size;
      status = ssl_handshaker_process_bytes_from_peer(
          impl, remaining_bytes_to_write_to_openssl, &bytes_written_to_openssl,
          error);
      // As long as the BIO is full, drive the SSL handshake to consume bytes
      // from the BIO. If the SSL handshake returns any bytes, write them to
      // the peer.
      while (status == TSI_DRAIN_BUFFER) {
        status =
            ssl_handshaker_write_output_buffer(self, &bytes_written, error);
        if (status != TSI_OK) return status;
        status = ssl_handshaker_do_handshake(impl, error);
      }
      // Move the pointer to the first byte not yet successfully written to
      // the BIO.
      remaining_bytes_to_write_to_openssl_size -= bytes_written_to_openssl;
      remaining_bytes_to_write_to_openssl += bytes_written_to_openssl;
    }
  }
  if (status != TSI_OK) return status;
  // Get bytes to send to the peer, if available.
  status = ssl_handshaker_write_output_buffer(self, &bytes_written, error);
  if (status != TSI_OK) return status;
  *bytes_to_send = impl->outgoing_bytes_buffer;
  *bytes_to_send_size = bytes_written;
  // If handshake completes, create tsi_handshaker_result.
  // NOTE(review): this branch runs for any recorded result other than
  // TSI_HANDSHAKE_IN_PROGRESS, so an error left in impl->result by an earlier
  // call would also produce a handshaker result. Confirm callers never call
  // next again after a failure.
  if (ssl_handshaker_get_result(impl) == TSI_HANDSHAKE_IN_PROGRESS) {
    *handshaker_result = nullptr;
  } else {
    // Any bytes that remain in |impl->ssl|'s read BIO after the handshake is
    // complete must be extracted and set to the unused bytes of the
    // handshaker result. This indicates to the gRPC stack that there are
    // bytes from the peer that must be processed.
    unsigned char* unused_bytes = nullptr;
    size_t unused_bytes_size = 0;
    status =
        ssl_bytes_remaining(impl, &unused_bytes, &unused_bytes_size, error);
    if (status != TSI_OK) return status;
    // Leftover bytes can only come from this call's input, so more of them
    // than were received indicates an internal inconsistency.
    if (unused_bytes_size > received_bytes_size) {
      gpr_log(GPR_ERROR, "More unused bytes than received bytes.");
      gpr_free(unused_bytes);
      if (error != nullptr) *error = "More unused bytes than received bytes.";
      return TSI_INTERNAL_ERROR;
    }
    // On success the result takes ownership of |unused_bytes|.
    status = ssl_handshaker_result_create(impl, unused_bytes, unused_bytes_size,
                                          handshaker_result, error);
    if (status == TSI_OK) {
      // Indicates that the handshake has completed and that a
      // handshaker_result has been created.
      self->handshaker_result_created = true;
    }
  }
  return status;
}
1840 
// Dispatch table for tsi_ssl_handshaker. The initializer is positional, so
// entries must stay in the field order of tsi_handshaker_vtable. Only destroy
// and next are implemented; the deprecated entry points and shutdown are null.
static const tsi_handshaker_vtable handshaker_vtable = {
    nullptr,  // get_bytes_to_send_to_peer -- deprecated
    nullptr,  // process_bytes_from_peer   -- deprecated
    nullptr,  // get_result                -- deprecated
    nullptr,  // extract_peer              -- deprecated
    nullptr,  // create_frame_protector    -- deprecated
    ssl_handshaker_destroy,
    ssl_handshaker_next,
    nullptr,  // shutdown
};
1851 
1852 // --- tsi_ssl_handshaker_factory common methods. ---
1853 
tsi_ssl_handshaker_resume_session(SSL * ssl,tsi::SslSessionLRUCache * session_cache)1854 static void tsi_ssl_handshaker_resume_session(
1855     SSL* ssl, tsi::SslSessionLRUCache* session_cache) {
1856   const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1857   if (server_name == nullptr) {
1858     return;
1859   }
1860   tsi::SslSessionPtr session = session_cache->Get(server_name);
1861   if (session != nullptr) {
1862     // SSL_set_session internally increments reference counter.
1863     SSL_set_session(ssl, session.get());
1864   }
1865 }
1866 
create_tsi_ssl_handshaker(SSL_CTX * ctx,int is_client,const char * server_name_indication,size_t network_bio_buf_size,size_t ssl_bio_buf_size,tsi_ssl_handshaker_factory * factory,tsi_handshaker ** handshaker)1867 static tsi_result create_tsi_ssl_handshaker(SSL_CTX* ctx, int is_client,
1868                                             const char* server_name_indication,
1869                                             size_t network_bio_buf_size,
1870                                             size_t ssl_bio_buf_size,
1871                                             tsi_ssl_handshaker_factory* factory,
1872                                             tsi_handshaker** handshaker) {
1873   SSL* ssl = SSL_new(ctx);
1874   BIO* network_io = nullptr;
1875   BIO* ssl_io = nullptr;
1876   tsi_ssl_handshaker* impl = nullptr;
1877   *handshaker = nullptr;
1878   if (ctx == nullptr) {
1879     gpr_log(GPR_ERROR, "SSL Context is null. Should never happen.");
1880     return TSI_INTERNAL_ERROR;
1881   }
1882   if (ssl == nullptr) {
1883     return TSI_OUT_OF_RESOURCES;
1884   }
1885   SSL_set_info_callback(ssl, ssl_info_callback);
1886 
1887   if (!BIO_new_bio_pair(&network_io, network_bio_buf_size, &ssl_io,
1888                         ssl_bio_buf_size)) {
1889     gpr_log(GPR_ERROR, "BIO_new_bio_pair failed.");
1890     SSL_free(ssl);
1891     return TSI_OUT_OF_RESOURCES;
1892   }
1893   SSL_set_bio(ssl, ssl_io, ssl_io);
1894   if (is_client) {
1895     int ssl_result;
1896     SSL_set_connect_state(ssl);
1897     // Skip if the SNI looks like an IP address because IP addressed are not
1898     // allowed as host names.
1899     if (server_name_indication != nullptr &&
1900         !looks_like_ip_address(server_name_indication)) {
1901       if (!SSL_set_tlsext_host_name(ssl, server_name_indication)) {
1902         gpr_log(GPR_ERROR, "Invalid server name indication %s.",
1903                 server_name_indication);
1904         SSL_free(ssl);
1905         BIO_free(network_io);
1906         return TSI_INTERNAL_ERROR;
1907       }
1908     }
1909     tsi_ssl_client_handshaker_factory* client_factory =
1910         reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
1911     if (client_factory->session_cache != nullptr) {
1912       tsi_ssl_handshaker_resume_session(ssl,
1913                                         client_factory->session_cache.get());
1914     }
1915     ERR_clear_error();
1916     ssl_result = SSL_do_handshake(ssl);
1917     ssl_result = SSL_get_error(ssl, ssl_result);
1918     if (ssl_result != SSL_ERROR_WANT_READ) {
1919       gpr_log(GPR_ERROR,
1920               "Unexpected error received from first SSL_do_handshake call: %s",
1921               grpc_core::SslErrorString(ssl_result));
1922       SSL_free(ssl);
1923       BIO_free(network_io);
1924       return TSI_INTERNAL_ERROR;
1925     }
1926   } else {
1927     SSL_set_accept_state(ssl);
1928   }
1929 
1930   impl = grpc_core::Zalloc<tsi_ssl_handshaker>();
1931   impl->ssl = ssl;
1932   impl->network_io = network_io;
1933   impl->result = TSI_HANDSHAKE_IN_PROGRESS;
1934   impl->outgoing_bytes_buffer_size =
1935       TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE;
1936   impl->outgoing_bytes_buffer =
1937       static_cast<unsigned char*>(gpr_zalloc(impl->outgoing_bytes_buffer_size));
1938   impl->base.vtable = &handshaker_vtable;
1939   impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory);
1940   *handshaker = &impl->base;
1941   return TSI_OK;
1942 }
1943 
select_protocol_list(const unsigned char ** out,unsigned char * outlen,const unsigned char * client_list,size_t client_list_len,const unsigned char * server_list,size_t server_list_len)1944 static int select_protocol_list(const unsigned char** out,
1945                                 unsigned char* outlen,
1946                                 const unsigned char* client_list,
1947                                 size_t client_list_len,
1948                                 const unsigned char* server_list,
1949                                 size_t server_list_len) {
1950   const unsigned char* client_current = client_list;
1951   while (static_cast<unsigned int>(client_current - client_list) <
1952          client_list_len) {
1953     unsigned char client_current_len = *(client_current++);
1954     const unsigned char* server_current = server_list;
1955     while ((server_current >= server_list) &&
1956            static_cast<uintptr_t>(server_current - server_list) <
1957                server_list_len) {
1958       unsigned char server_current_len = *(server_current++);
1959       if ((client_current_len == server_current_len) &&
1960           !memcmp(client_current, server_current, server_current_len)) {
1961         *out = server_current;
1962         *outlen = server_current_len;
1963         return SSL_TLSEXT_ERR_OK;
1964       }
1965       server_current += server_current_len;
1966     }
1967     client_current += client_current_len;
1968   }
1969   return SSL_TLSEXT_ERR_NOACK;
1970 }
1971 
1972 // --- tsi_ssl_client_handshaker_factory methods implementation. ---
1973 
tsi_ssl_client_handshaker_factory_create_handshaker(tsi_ssl_client_handshaker_factory * factory,const char * server_name_indication,size_t network_bio_buf_size,size_t ssl_bio_buf_size,tsi_handshaker ** handshaker)1974 tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
1975     tsi_ssl_client_handshaker_factory* factory,
1976     const char* server_name_indication, size_t network_bio_buf_size,
1977     size_t ssl_bio_buf_size, tsi_handshaker** handshaker) {
1978   return create_tsi_ssl_handshaker(
1979       factory->ssl_context, 1, server_name_indication, network_bio_buf_size,
1980       ssl_bio_buf_size, &factory->base, handshaker);
1981 }
1982 
tsi_ssl_client_handshaker_factory_unref(tsi_ssl_client_handshaker_factory * factory)1983 void tsi_ssl_client_handshaker_factory_unref(
1984     tsi_ssl_client_handshaker_factory* factory) {
1985   if (factory == nullptr) return;
1986   tsi_ssl_handshaker_factory_unref(&factory->base);
1987 }
1988 
tsi_ssl_client_handshaker_factory_ref(tsi_ssl_client_handshaker_factory * client_factory)1989 tsi_ssl_client_handshaker_factory* tsi_ssl_client_handshaker_factory_ref(
1990     tsi_ssl_client_handshaker_factory* client_factory) {
1991   if (client_factory == nullptr) return nullptr;
1992   return reinterpret_cast<tsi_ssl_client_handshaker_factory*>(
1993       tsi_ssl_handshaker_factory_ref(&client_factory->base));
1994 }
1995 
tsi_ssl_client_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)1996 static void tsi_ssl_client_handshaker_factory_destroy(
1997     tsi_ssl_handshaker_factory* factory) {
1998   if (factory == nullptr) return;
1999   tsi_ssl_client_handshaker_factory* self =
2000       reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
2001   if (self->ssl_context != nullptr) SSL_CTX_free(self->ssl_context);
2002   if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
2003   self->session_cache.reset();
2004   self->key_logger.reset();
2005   gpr_free(self);
2006 }
2007 
client_handshaker_factory_npn_callback(SSL *,unsigned char ** out,unsigned char * outlen,const unsigned char * in,unsigned int inlen,void * arg)2008 static int client_handshaker_factory_npn_callback(
2009     SSL* /*ssl*/, unsigned char** out, unsigned char* outlen,
2010     const unsigned char* in, unsigned int inlen, void* arg) {
2011   tsi_ssl_client_handshaker_factory* factory =
2012       static_cast<tsi_ssl_client_handshaker_factory*>(arg);
2013   return select_protocol_list(const_cast<const unsigned char**>(out), outlen,
2014                               factory->alpn_protocol_list,
2015                               factory->alpn_protocol_list_length, in, inlen);
2016 }
2017 
2018 // --- tsi_ssl_server_handshaker_factory methods implementation. ---
2019 
tsi_ssl_server_handshaker_factory_create_handshaker(tsi_ssl_server_handshaker_factory * factory,size_t network_bio_buf_size,size_t ssl_bio_buf_size,tsi_handshaker ** handshaker)2020 tsi_result tsi_ssl_server_handshaker_factory_create_handshaker(
2021     tsi_ssl_server_handshaker_factory* factory, size_t network_bio_buf_size,
2022     size_t ssl_bio_buf_size, tsi_handshaker** handshaker) {
2023   if (factory->ssl_context_count == 0) return TSI_INVALID_ARGUMENT;
2024   // Create the handshaker with the first context. We will switch if needed
2025   // because of SNI in ssl_server_handshaker_factory_servername_callback.
2026   return create_tsi_ssl_handshaker(factory->ssl_contexts[0], 0, nullptr,
2027                                    network_bio_buf_size, ssl_bio_buf_size,
2028                                    &factory->base, handshaker);
2029 }
2030 
tsi_ssl_server_handshaker_factory_unref(tsi_ssl_server_handshaker_factory * factory)2031 void tsi_ssl_server_handshaker_factory_unref(
2032     tsi_ssl_server_handshaker_factory* factory) {
2033   if (factory == nullptr) return;
2034   tsi_ssl_handshaker_factory_unref(&factory->base);
2035 }
2036 
tsi_ssl_server_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)2037 static void tsi_ssl_server_handshaker_factory_destroy(
2038     tsi_ssl_handshaker_factory* factory) {
2039   if (factory == nullptr) return;
2040   tsi_ssl_server_handshaker_factory* self =
2041       reinterpret_cast<tsi_ssl_server_handshaker_factory*>(factory);
2042   size_t i;
2043   for (i = 0; i < self->ssl_context_count; i++) {
2044     if (self->ssl_contexts[i] != nullptr) {
2045       SSL_CTX_free(self->ssl_contexts[i]);
2046       tsi_peer_destruct(&self->ssl_context_x509_subject_names[i]);
2047     }
2048   }
2049   if (self->ssl_contexts != nullptr) gpr_free(self->ssl_contexts);
2050   if (self->ssl_context_x509_subject_names != nullptr) {
2051     gpr_free(self->ssl_context_x509_subject_names);
2052   }
2053   if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
2054   self->key_logger.reset();
2055   gpr_free(self);
2056 }
2057 
does_entry_match_name(absl::string_view entry,absl::string_view name)2058 static int does_entry_match_name(absl::string_view entry,
2059                                  absl::string_view name) {
2060   if (entry.empty()) return 0;
2061 
2062   // Take care of '.' terminations.
2063   if (name.back() == '.') {
2064     name.remove_suffix(1);
2065   }
2066   if (entry.back() == '.') {
2067     entry.remove_suffix(1);
2068     if (entry.empty()) return 0;
2069   }
2070 
2071   if (absl::EqualsIgnoreCase(name, entry)) {
2072     return 1;  // Perfect match.
2073   }
2074   if (entry.front() != '*') return 0;
2075 
2076   // Wildchar subdomain matching.
2077   if (entry.size() < 3 || entry[1] != '.') {  // At least *.x
2078     gpr_log(GPR_ERROR, "Invalid wildchar entry.");
2079     return 0;
2080   }
2081   size_t name_subdomain_pos = name.find('.');
2082   if (name_subdomain_pos == absl::string_view::npos) return 0;
2083   if (name_subdomain_pos >= name.size() - 2) return 0;
2084   absl::string_view name_subdomain =
2085       name.substr(name_subdomain_pos + 1);  // Starts after the dot.
2086   entry.remove_prefix(2);                   // Remove *.
2087   size_t dot = name_subdomain.find('.');
2088   if (dot == absl::string_view::npos || dot == name_subdomain.size() - 1) {
2089     gpr_log(GPR_ERROR, "Invalid toplevel subdomain: %s",
2090             std::string(name_subdomain).c_str());
2091     return 0;
2092   }
2093   if (name_subdomain.back() == '.') {
2094     name_subdomain.remove_suffix(1);
2095   }
2096   return !entry.empty() && absl::EqualsIgnoreCase(name_subdomain, entry);
2097 }
2098 
ssl_server_handshaker_factory_servername_callback(SSL * ssl,int *,void * arg)2099 static int ssl_server_handshaker_factory_servername_callback(SSL* ssl,
2100                                                              int* /*ap*/,
2101                                                              void* arg) {
2102   tsi_ssl_server_handshaker_factory* impl =
2103       static_cast<tsi_ssl_server_handshaker_factory*>(arg);
2104   size_t i = 0;
2105   const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
2106   if (servername == nullptr || strlen(servername) == 0) {
2107     return SSL_TLSEXT_ERR_NOACK;
2108   }
2109 
2110   for (i = 0; i < impl->ssl_context_count; i++) {
2111     if (tsi_ssl_peer_matches_name(&impl->ssl_context_x509_subject_names[i],
2112                                   servername)) {
2113       SSL_set_SSL_CTX(ssl, impl->ssl_contexts[i]);
2114       return SSL_TLSEXT_ERR_OK;
2115     }
2116   }
2117   gpr_log(GPR_ERROR, "No match found for server name: %s.", servername);
2118   return SSL_TLSEXT_ERR_NOACK;
2119 }
2120 
#if TSI_OPENSSL_ALPN_SUPPORT
// ALPN selection callback for server SSL_CTXs: delegates to
// select_protocol_list with the client's offered list (|in|, |inlen|) and the
// factory's configured protocol list, writing the chosen protocol to
// |out|/|outlen|. Its return value is passed straight back to OpenSSL.
static int server_handshaker_factory_alpn_callback(
    SSL* /*ssl*/, const unsigned char** out, unsigned char* outlen,
    const unsigned char* in, unsigned int inlen, void* arg) {
  const auto* server_factory =
      static_cast<const tsi_ssl_server_handshaker_factory*>(arg);
  const auto* server_list = server_factory->alpn_protocol_list;
  const auto server_list_length = server_factory->alpn_protocol_list_length;
  return select_protocol_list(out, outlen, in, inlen, server_list,
                              server_list_length);
}
#endif  // TSI_OPENSSL_ALPN_SUPPORT
2132 
server_handshaker_factory_npn_advertised_callback(SSL *,const unsigned char ** out,unsigned int * outlen,void * arg)2133 static int server_handshaker_factory_npn_advertised_callback(
2134     SSL* /*ssl*/, const unsigned char** out, unsigned int* outlen, void* arg) {
2135   tsi_ssl_server_handshaker_factory* factory =
2136       static_cast<tsi_ssl_server_handshaker_factory*>(arg);
2137   *out = factory->alpn_protocol_list;
2138   GPR_ASSERT(factory->alpn_protocol_list_length <= UINT_MAX);
2139   *outlen = static_cast<unsigned int>(factory->alpn_protocol_list_length);
2140   return SSL_TLSEXT_ERR_OK;
2141 }
2142 
2143 /// This callback is called when new \a session is established and ready to
2144 /// be cached. This session can be reused for new connections to similar
2145 /// servers at later point of time.
2146 /// It's intended to be used with SSL_CTX_sess_set_new_cb function.
2147 ///
2148 /// It returns 1 if callback takes ownership over \a session and 0 otherwise.
server_handshaker_factory_new_session_callback(SSL * ssl,SSL_SESSION * session)2149 static int server_handshaker_factory_new_session_callback(
2150     SSL* ssl, SSL_SESSION* session) {
2151   SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl);
2152   if (ssl_context == nullptr) {
2153     return 0;
2154   }
2155   void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index);
2156   tsi_ssl_client_handshaker_factory* factory =
2157       static_cast<tsi_ssl_client_handshaker_factory*>(arg);
2158   const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
2159   if (server_name == nullptr) {
2160     return 0;
2161   }
2162   factory->session_cache->Put(server_name, tsi::SslSessionPtr(session));
2163   // Return 1 to indicate transferred ownership over the given session.
2164   return 1;
2165 }
2166 
2167 /// This callback is invoked at client or server when ssl/tls handshakes
2168 /// complete and keylogging is enabled.
2169 template <typename T>
ssl_keylogging_callback(const SSL * ssl,const char * info)2170 static void ssl_keylogging_callback(const SSL* ssl, const char* info) {
2171   SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl);
2172   GPR_ASSERT(ssl_context != nullptr);
2173   void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index);
2174   T* factory = static_cast<T*>(arg);
2175   factory->key_logger->LogSessionKeys(ssl_context, info);
2176 }
2177 
2178 // --- tsi_ssl_handshaker_factory constructors. ---
2179 
2180 static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = {
2181     tsi_ssl_client_handshaker_factory_destroy};
2182 
tsi_create_ssl_client_handshaker_factory(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pair,const char * pem_root_certs,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_client_handshaker_factory ** factory)2183 tsi_result tsi_create_ssl_client_handshaker_factory(
2184     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pair,
2185     const char* pem_root_certs, const char* cipher_suites,
2186     const char** alpn_protocols, uint16_t num_alpn_protocols,
2187     tsi_ssl_client_handshaker_factory** factory) {
2188   tsi_ssl_client_handshaker_options options;
2189   options.pem_key_cert_pair = pem_key_cert_pair;
2190   options.pem_root_certs = pem_root_certs;
2191   options.cipher_suites = cipher_suites;
2192   options.alpn_protocols = alpn_protocols;
2193   options.num_alpn_protocols = num_alpn_protocols;
2194   return tsi_create_ssl_client_handshaker_factory_with_options(&options,
2195                                                                factory);
2196 }
2197 
tsi_create_ssl_client_handshaker_factory_with_options(const tsi_ssl_client_handshaker_options * options,tsi_ssl_client_handshaker_factory ** factory)2198 tsi_result tsi_create_ssl_client_handshaker_factory_with_options(
2199     const tsi_ssl_client_handshaker_options* options,
2200     tsi_ssl_client_handshaker_factory** factory) {
2201   SSL_CTX* ssl_context = nullptr;
2202   tsi_ssl_client_handshaker_factory* impl = nullptr;
2203   tsi_result result = TSI_OK;
2204 
2205   gpr_once_init(&g_init_openssl_once, init_openssl);
2206 
2207   if (factory == nullptr) return TSI_INVALID_ARGUMENT;
2208   *factory = nullptr;
2209   if (options->pem_root_certs == nullptr && options->root_store == nullptr &&
2210       !options->skip_server_certificate_verification) {
2211     return TSI_INVALID_ARGUMENT;
2212   }
2213 
2214 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2215   ssl_context = SSL_CTX_new(TLS_method());
2216 #else
2217   ssl_context = SSL_CTX_new(TLSv1_2_method());
2218 #endif
2219 #if OPENSSL_VERSION_NUMBER >= 0x10101000 && !defined(LIBRESSL_VERSION_NUMBER)
2220   SSL_CTX_set_options(ssl_context, SSL_OP_NO_RENEGOTIATION);
2221 #endif
2222   if (ssl_context == nullptr) {
2223     grpc_core::LogSslErrorStack();
2224     gpr_log(GPR_ERROR, "Could not create ssl context.");
2225     return TSI_INVALID_ARGUMENT;
2226   }
2227 
2228   result = tsi_set_min_and_max_tls_versions(
2229       ssl_context, options->min_tls_version, options->max_tls_version);
2230   if (result != TSI_OK) return result;
2231 
2232   impl = static_cast<tsi_ssl_client_handshaker_factory*>(
2233       gpr_zalloc(sizeof(*impl)));
2234   tsi_ssl_handshaker_factory_init(&impl->base);
2235   impl->base.vtable = &client_handshaker_factory_vtable;
2236   impl->ssl_context = ssl_context;
2237   if (options->session_cache != nullptr) {
2238     // Unref is called manually on factory destruction.
2239     impl->session_cache =
2240         reinterpret_cast<tsi::SslSessionLRUCache*>(options->session_cache)
2241             ->Ref();
2242     SSL_CTX_sess_set_new_cb(ssl_context,
2243                             server_handshaker_factory_new_session_callback);
2244     SSL_CTX_set_session_cache_mode(ssl_context, SSL_SESS_CACHE_CLIENT);
2245   }
2246 
2247 #if OPENSSL_VERSION_NUMBER >= 0x10101000 && !defined(LIBRESSL_VERSION_NUMBER)
2248   if (options->key_logger != nullptr) {
2249     impl->key_logger = options->key_logger->Ref();
2250     // SSL_CTX_set_keylog_callback is set here to register callback
2251     // when ssl/tls handshakes complete.
2252     SSL_CTX_set_keylog_callback(
2253         ssl_context,
2254         ssl_keylogging_callback<tsi_ssl_client_handshaker_factory>);
2255   }
2256 #endif
2257 
2258   if (options->session_cache != nullptr || options->key_logger != nullptr) {
2259     // Need to set factory at g_ssl_ctx_ex_factory_index
2260     SSL_CTX_set_ex_data(ssl_context, g_ssl_ctx_ex_factory_index, impl);
2261   }
2262 
2263   do {
2264     result = populate_ssl_context(ssl_context, options->pem_key_cert_pair,
2265                                   options->cipher_suites);
2266     if (result != TSI_OK) break;
2267 
2268 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2269     // X509_STORE_up_ref is only available since OpenSSL 1.1.
2270     if (options->root_store != nullptr) {
2271       X509_STORE_up_ref(options->root_store->store);
2272       SSL_CTX_set_cert_store(ssl_context, options->root_store->store);
2273     }
2274 #endif
2275     if (OPENSSL_VERSION_NUMBER < 0x10100000 ||
2276         (options->root_store == nullptr &&
2277          options->pem_root_certs != nullptr)) {
2278       result = ssl_ctx_load_verification_certs(
2279           ssl_context, options->pem_root_certs, strlen(options->pem_root_certs),
2280           nullptr);
2281       X509_STORE* cert_store = SSL_CTX_get_cert_store(ssl_context);
2282 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2283       X509_VERIFY_PARAM* param = X509_STORE_get0_param(cert_store);
2284 
2285 #else
2286       X509_VERIFY_PARAM* param = cert_store->param;
2287 #endif
2288 
2289       X509_VERIFY_PARAM_set_depth(param, kMaxChainLength);
2290       if (result != TSI_OK) {
2291         gpr_log(GPR_ERROR, "Cannot load server root certificates.");
2292         break;
2293       }
2294     }
2295 
2296     if (options->num_alpn_protocols != 0) {
2297       result = build_alpn_protocol_name_list(
2298           options->alpn_protocols, options->num_alpn_protocols,
2299           &impl->alpn_protocol_list, &impl->alpn_protocol_list_length);
2300       if (result != TSI_OK) {
2301         gpr_log(GPR_ERROR, "Building alpn list failed with error %s.",
2302                 tsi_result_to_string(result));
2303         break;
2304       }
2305 #if TSI_OPENSSL_ALPN_SUPPORT
2306       GPR_ASSERT(impl->alpn_protocol_list_length < UINT_MAX);
2307       if (SSL_CTX_set_alpn_protos(
2308               ssl_context, impl->alpn_protocol_list,
2309               static_cast<unsigned int>(impl->alpn_protocol_list_length))) {
2310         gpr_log(GPR_ERROR, "Could not set alpn protocol list to context.");
2311         result = TSI_INVALID_ARGUMENT;
2312         break;
2313       }
2314 #endif  // TSI_OPENSSL_ALPN_SUPPORT
2315       SSL_CTX_set_next_proto_select_cb(
2316           ssl_context, client_handshaker_factory_npn_callback, impl);
2317     }
2318   } while (false);
2319   if (result != TSI_OK) {
2320     tsi_ssl_handshaker_factory_unref(&impl->base);
2321     return result;
2322   }
2323   SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, nullptr);
2324   if (options->skip_server_certificate_verification) {
2325     SSL_CTX_set_cert_verify_callback(ssl_context, NullVerifyCallback, nullptr);
2326   } else {
2327     SSL_CTX_set_cert_verify_callback(ssl_context, CustomVerificationFunction,
2328                                      nullptr);
2329   }
2330 #if OPENSSL_VERSION_NUMBER >= 0x10100000 && !defined(LIBRESSL_VERSION_NUMBER)
2331   if (options->crl_provider != nullptr) {
2332     SSL_CTX_set_ex_data(impl->ssl_context, g_ssl_ctx_ex_crl_provider_index,
2333                         options->crl_provider.get());
2334   } else if (options->crl_directory != nullptr &&
2335              strcmp(options->crl_directory, "") != 0) {
2336     X509_STORE* cert_store = SSL_CTX_get_cert_store(ssl_context);
2337     X509_STORE_set_verify_cb(cert_store, verify_cb);
2338     if (!X509_STORE_load_locations(cert_store, nullptr,
2339                                    options->crl_directory)) {
2340       gpr_log(GPR_ERROR, "Failed to load CRL File from directory.");
2341     } else {
2342       X509_VERIFY_PARAM* param = X509_STORE_get0_param(cert_store);
2343       X509_VERIFY_PARAM_set_flags(
2344           param, X509_V_FLAG_CRL_CHECK | X509_V_FLAG_CRL_CHECK_ALL);
2345     }
2346   }
2347 #endif
2348 
2349   *factory = impl;
2350   return TSI_OK;
2351 }
2352 
2353 static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = {
2354     tsi_ssl_server_handshaker_factory_destroy};
2355 
tsi_create_ssl_server_handshaker_factory(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pairs,size_t num_key_cert_pairs,const char * pem_client_root_certs,int force_client_auth,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_server_handshaker_factory ** factory)2356 tsi_result tsi_create_ssl_server_handshaker_factory(
2357     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
2358     size_t num_key_cert_pairs, const char* pem_client_root_certs,
2359     int force_client_auth, const char* cipher_suites,
2360     const char** alpn_protocols, uint16_t num_alpn_protocols,
2361     tsi_ssl_server_handshaker_factory** factory) {
2362   return tsi_create_ssl_server_handshaker_factory_ex(
2363       pem_key_cert_pairs, num_key_cert_pairs, pem_client_root_certs,
2364       force_client_auth ? TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY
2365                         : TSI_DONT_REQUEST_CLIENT_CERTIFICATE,
2366       cipher_suites, alpn_protocols, num_alpn_protocols, factory);
2367 }
2368 
tsi_create_ssl_server_handshaker_factory_ex(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pairs,size_t num_key_cert_pairs,const char * pem_client_root_certs,tsi_client_certificate_request_type client_certificate_request,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_server_handshaker_factory ** factory)2369 tsi_result tsi_create_ssl_server_handshaker_factory_ex(
2370     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
2371     size_t num_key_cert_pairs, const char* pem_client_root_certs,
2372     tsi_client_certificate_request_type client_certificate_request,
2373     const char* cipher_suites, const char** alpn_protocols,
2374     uint16_t num_alpn_protocols, tsi_ssl_server_handshaker_factory** factory) {
2375   tsi_ssl_server_handshaker_options options;
2376   options.pem_key_cert_pairs = pem_key_cert_pairs;
2377   options.num_key_cert_pairs = num_key_cert_pairs;
2378   options.pem_client_root_certs = pem_client_root_certs;
2379   options.client_certificate_request = client_certificate_request;
2380   options.cipher_suites = cipher_suites;
2381   options.alpn_protocols = alpn_protocols;
2382   options.num_alpn_protocols = num_alpn_protocols;
2383   return tsi_create_ssl_server_handshaker_factory_with_options(&options,
2384                                                                factory);
2385 }
2386 
tsi_create_ssl_server_handshaker_factory_with_options(const tsi_ssl_server_handshaker_options * options,tsi_ssl_server_handshaker_factory ** factory)2387 tsi_result tsi_create_ssl_server_handshaker_factory_with_options(
2388     const tsi_ssl_server_handshaker_options* options,
2389     tsi_ssl_server_handshaker_factory** factory) {
2390   tsi_ssl_server_handshaker_factory* impl = nullptr;
2391   tsi_result result = TSI_OK;
2392   size_t i = 0;
2393 
2394   gpr_once_init(&g_init_openssl_once, init_openssl);
2395 
2396   if (factory == nullptr) return TSI_INVALID_ARGUMENT;
2397   *factory = nullptr;
2398   if (options->num_key_cert_pairs == 0 ||
2399       options->pem_key_cert_pairs == nullptr) {
2400     return TSI_INVALID_ARGUMENT;
2401   }
2402 
2403   impl = static_cast<tsi_ssl_server_handshaker_factory*>(
2404       gpr_zalloc(sizeof(*impl)));
2405   tsi_ssl_handshaker_factory_init(&impl->base);
2406   impl->base.vtable = &server_handshaker_factory_vtable;
2407 
2408   impl->ssl_contexts = static_cast<SSL_CTX**>(
2409       gpr_zalloc(options->num_key_cert_pairs * sizeof(SSL_CTX*)));
2410   impl->ssl_context_x509_subject_names = static_cast<tsi_peer*>(
2411       gpr_zalloc(options->num_key_cert_pairs * sizeof(tsi_peer)));
2412   if (impl->ssl_contexts == nullptr ||
2413       impl->ssl_context_x509_subject_names == nullptr) {
2414     tsi_ssl_handshaker_factory_unref(&impl->base);
2415     return TSI_OUT_OF_RESOURCES;
2416   }
2417   impl->ssl_context_count = options->num_key_cert_pairs;
2418 
2419   if (options->num_alpn_protocols > 0) {
2420     result = build_alpn_protocol_name_list(
2421         options->alpn_protocols, options->num_alpn_protocols,
2422         &impl->alpn_protocol_list, &impl->alpn_protocol_list_length);
2423     if (result != TSI_OK) {
2424       tsi_ssl_handshaker_factory_unref(&impl->base);
2425       return result;
2426     }
2427   }
2428 
2429   if (options->key_logger != nullptr) {
2430     impl->key_logger = options->key_logger->Ref();
2431   }
2432 
2433   for (i = 0; i < options->num_key_cert_pairs; i++) {
2434     do {
2435 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2436       impl->ssl_contexts[i] = SSL_CTX_new(TLS_method());
2437 #else
2438       impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method());
2439 #endif
2440 #if OPENSSL_VERSION_NUMBER >= 0x10101000 && !defined(LIBRESSL_VERSION_NUMBER)
2441       SSL_CTX_set_options(impl->ssl_contexts[i], SSL_OP_NO_RENEGOTIATION);
2442 #endif
2443       if (impl->ssl_contexts[i] == nullptr) {
2444         grpc_core::LogSslErrorStack();
2445         gpr_log(GPR_ERROR, "Could not create ssl context.");
2446         result = TSI_OUT_OF_RESOURCES;
2447         break;
2448       }
2449 
2450       result = tsi_set_min_and_max_tls_versions(impl->ssl_contexts[i],
2451                                                 options->min_tls_version,
2452                                                 options->max_tls_version);
2453       if (result != TSI_OK) return result;
2454 
2455       result = populate_ssl_context(impl->ssl_contexts[i],
2456                                     &options->pem_key_cert_pairs[i],
2457                                     options->cipher_suites);
2458       if (result != TSI_OK) break;
2459 
2460       // TODO(elessar): Provide ability to disable session ticket keys.
2461 
2462       // Allow client cache sessions (it's needed for OpenSSL only).
2463       int set_sid_ctx_result = SSL_CTX_set_session_id_context(
2464           impl->ssl_contexts[i], kSslSessionIdContext,
2465           GPR_ARRAY_SIZE(kSslSessionIdContext));
2466       if (set_sid_ctx_result == 0) {
2467         gpr_log(GPR_ERROR, "Failed to set session id context.");
2468         result = TSI_INTERNAL_ERROR;
2469         break;
2470       }
2471 
2472       if (options->session_ticket_key != nullptr) {
2473         if (SSL_CTX_set_tlsext_ticket_keys(
2474                 impl->ssl_contexts[i],
2475                 const_cast<char*>(options->session_ticket_key),
2476                 options->session_ticket_key_size) == 0) {
2477           gpr_log(GPR_ERROR, "Invalid STEK size.");
2478           result = TSI_INVALID_ARGUMENT;
2479           break;
2480         }
2481       }
2482 
2483       if (options->pem_client_root_certs != nullptr) {
2484         STACK_OF(X509_NAME)* root_names = nullptr;
2485         result = ssl_ctx_load_verification_certs(
2486             impl->ssl_contexts[i], options->pem_client_root_certs,
2487             strlen(options->pem_client_root_certs),
2488             options->send_client_ca_list ? &root_names : nullptr);
2489         if (result != TSI_OK) {
2490           gpr_log(GPR_ERROR, "Invalid verification certs.");
2491           break;
2492         }
2493         if (options->send_client_ca_list) {
2494           SSL_CTX_set_client_CA_list(impl->ssl_contexts[i], root_names);
2495         }
2496       }
2497       switch (options->client_certificate_request) {
2498         case TSI_DONT_REQUEST_CLIENT_CERTIFICATE:
2499           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr);
2500           break;
2501         case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
2502           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr);
2503           SSL_CTX_set_cert_verify_callback(impl->ssl_contexts[i],
2504                                            NullVerifyCallback, nullptr);
2505           break;
2506         case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
2507           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr);
2508           SSL_CTX_set_cert_verify_callback(impl->ssl_contexts[i],
2509                                            CustomVerificationFunction, nullptr);
2510           break;
2511         case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
2512           SSL_CTX_set_verify(impl->ssl_contexts[i],
2513                              SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
2514                              nullptr);
2515           SSL_CTX_set_cert_verify_callback(impl->ssl_contexts[i],
2516                                            NullVerifyCallback, nullptr);
2517           break;
2518         case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
2519           SSL_CTX_set_verify(impl->ssl_contexts[i],
2520                              SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
2521                              nullptr);
2522           SSL_CTX_set_cert_verify_callback(impl->ssl_contexts[i],
2523                                            CustomVerificationFunction, nullptr);
2524           break;
2525       }
2526 
2527 #if OPENSSL_VERSION_NUMBER >= 0x10100000 && !defined(LIBRESSL_VERSION_NUMBER)
2528       if (options->crl_provider != nullptr) {
2529         SSL_CTX_set_ex_data(impl->ssl_contexts[i],
2530                             g_ssl_ctx_ex_crl_provider_index,
2531                             options->crl_provider.get());
2532       } else if (options->crl_directory != nullptr &&
2533                  strcmp(options->crl_directory, "") != 0) {
2534         X509_STORE* cert_store = SSL_CTX_get_cert_store(impl->ssl_contexts[i]);
2535         X509_STORE_set_verify_cb(cert_store, verify_cb);
2536         if (!X509_STORE_load_locations(cert_store, nullptr,
2537                                        options->crl_directory)) {
2538           gpr_log(GPR_ERROR, "Failed to load CRL File from directory.");
2539         } else {
2540           X509_VERIFY_PARAM* param = X509_STORE_get0_param(cert_store);
2541           X509_VERIFY_PARAM_set_flags(
2542               param, X509_V_FLAG_CRL_CHECK | X509_V_FLAG_CRL_CHECK_ALL);
2543         }
2544       }
2545 #endif
2546 
2547       result = tsi_ssl_extract_x509_subject_names_from_pem_cert(
2548           options->pem_key_cert_pairs[i].cert_chain,
2549           &impl->ssl_context_x509_subject_names[i]);
2550       if (result != TSI_OK) break;
2551 
2552       SSL_CTX_set_tlsext_servername_callback(
2553           impl->ssl_contexts[i],
2554           ssl_server_handshaker_factory_servername_callback);
2555       SSL_CTX_set_tlsext_servername_arg(impl->ssl_contexts[i], impl);
2556 #if TSI_OPENSSL_ALPN_SUPPORT
2557       SSL_CTX_set_alpn_select_cb(impl->ssl_contexts[i],
2558                                  server_handshaker_factory_alpn_callback, impl);
2559 #endif  // TSI_OPENSSL_ALPN_SUPPORT
2560       SSL_CTX_set_next_protos_advertised_cb(
2561           impl->ssl_contexts[i],
2562           server_handshaker_factory_npn_advertised_callback, impl);
2563 
2564 #if OPENSSL_VERSION_NUMBER >= 0x10101000 && !defined(LIBRESSL_VERSION_NUMBER)
2565       // Register factory at index
2566       if (options->key_logger != nullptr) {
2567         // Need to set factory at g_ssl_ctx_ex_factory_index
2568         SSL_CTX_set_ex_data(impl->ssl_contexts[i], g_ssl_ctx_ex_factory_index,
2569                             impl);
2570         // SSL_CTX_set_keylog_callback is set here to register callback
2571         // when ssl/tls handshakes complete.
2572         SSL_CTX_set_keylog_callback(
2573             impl->ssl_contexts[i],
2574             ssl_keylogging_callback<tsi_ssl_server_handshaker_factory>);
2575       }
2576 #endif
2577     } while (false);
2578 
2579     if (result != TSI_OK) {
2580       tsi_ssl_handshaker_factory_unref(&impl->base);
2581       return result;
2582     }
2583   }
2584 
2585   *factory = impl;
2586   return TSI_OK;
2587 }
2588 
2589 // --- tsi_ssl utils. ---
2590 
tsi_ssl_peer_matches_name(const tsi_peer * peer,absl::string_view name)2591 int tsi_ssl_peer_matches_name(const tsi_peer* peer, absl::string_view name) {
2592   size_t i = 0;
2593   size_t san_count = 0;
2594   const tsi_peer_property* cn_property = nullptr;
2595   int like_ip = looks_like_ip_address(name);
2596 
2597   // Check the SAN first.
2598   for (i = 0; i < peer->property_count; i++) {
2599     const tsi_peer_property* property = &peer->properties[i];
2600     if (property->name == nullptr) continue;
2601     if (strcmp(property->name,
2602                TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) {
2603       san_count++;
2604 
2605       absl::string_view entry(property->value.data, property->value.length);
2606       if (!like_ip && does_entry_match_name(entry, name)) {
2607         return 1;
2608       } else if (like_ip && name == entry) {
2609         // IP Addresses are exact matches only.
2610         return 1;
2611       }
2612     } else if (strcmp(property->name,
2613                       TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY) == 0) {
2614       cn_property = property;
2615     }
2616   }
2617 
2618   // If there's no SAN, try the CN, but only if its not like an IP Address
2619   if (san_count == 0 && cn_property != nullptr && !like_ip) {
2620     if (does_entry_match_name(absl::string_view(cn_property->value.data,
2621                                                 cn_property->value.length),
2622                               name)) {
2623       return 1;
2624     }
2625   }
2626 
2627   return 0;  // Not found.
2628 }
2629 
2630 // --- Testing support. ---
tsi_ssl_handshaker_factory_swap_vtable(tsi_ssl_handshaker_factory * factory,tsi_ssl_handshaker_factory_vtable * new_vtable)2631 const tsi_ssl_handshaker_factory_vtable* tsi_ssl_handshaker_factory_swap_vtable(
2632     tsi_ssl_handshaker_factory* factory,
2633     tsi_ssl_handshaker_factory_vtable* new_vtable) {
2634   GPR_ASSERT(factory != nullptr);
2635   GPR_ASSERT(factory->vtable != nullptr);
2636 
2637   const tsi_ssl_handshaker_factory_vtable* orig_vtable = factory->vtable;
2638   factory->vtable = new_vtable;
2639   return orig_vtable;
2640 }
2641