xref: /aosp_15_r20/external/grpc-grpc/test/cpp/util/test_credentials_provider.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 
2 //
3 //
4 // Copyright 2016 gRPC authors.
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 //
10 //     http://www.apache.org/licenses/LICENSE-2.0
11 //
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 //
18 //
19 
20 #include "test/cpp/util/test_credentials_provider.h"
21 
22 #include <cstdio>
23 #include <fstream>
24 #include <iostream>
25 #include <mutex>
26 #include <unordered_map>
27 
28 #include "absl/flags/flag.h"
29 
30 #include <grpc/support/log.h>
31 #include <grpc/support/sync.h>
32 #include <grpcpp/security/server_credentials.h>
33 
34 #include "src/core/lib/gprpp/crash.h"
35 #include "test/core/end2end/data/ssl_test_data.h"
36 
// Optional PEM files supplying a custom TLS server identity. Both are read by
// the DefaultCredentialsProvider constructor when non-empty.
// GetServerCredentials() uses them only when BOTH yield non-empty contents;
// otherwise it falls back to the bundled test_server1 key/cert pair.
ABSL_FLAG(std::string, tls_cert_file, "",
          "The TLS cert file used when --use_tls=true");
ABSL_FLAG(std::string, tls_key_file, "",
          "The TLS key file used when --use_tls=true");
41 
42 namespace grpc {
43 namespace testing {
44 namespace {
45 
// Reads the whole file at `src_path` in binary mode and returns its bytes.
//
// If the file cannot be opened (for example a mistyped --tls_key_file or
// --tls_cert_file path), the path is reported on stderr and an empty string
// is returned.
std::string ReadFile(const std::string& src_path) {
  std::ifstream src(src_path, std::ifstream::in | std::ifstream::binary);
  // The open must be checked explicitly. On failure tellg() returns -1, and
  // reserve() would turn that into a huge size_t and throw
  // std::length_error, hiding which file was at fault.
  if (!src.is_open()) {
    std::cerr << "Failed to open file: " << src_path << '\n';
    return std::string();
  }

  std::string contents;
  // Pre-size the buffer when the stream can report its length. Skip the
  // reservation when tellg() reports an error (-1) or an empty file.
  src.seekg(0, std::ios::end);
  const std::streamoff size = src.tellg();
  if (size > 0) {
    contents.reserve(static_cast<std::string::size_type>(size));
  }
  // Clear any failbit left by the probe above before rewinding.
  src.clear();
  src.seekg(0, std::ios::beg);
  contents.assign(std::istreambuf_iterator<char>(src),
                  std::istreambuf_iterator<char>());
  return contents;
}
58 
59 class DefaultCredentialsProvider : public CredentialsProvider {
60  public:
DefaultCredentialsProvider()61   DefaultCredentialsProvider() {
62     if (!absl::GetFlag(FLAGS_tls_key_file).empty()) {
63       custom_server_key_ = ReadFile(absl::GetFlag(FLAGS_tls_key_file));
64     }
65     if (!absl::GetFlag(FLAGS_tls_cert_file).empty()) {
66       custom_server_cert_ = ReadFile(absl::GetFlag(FLAGS_tls_cert_file));
67     }
68   }
~DefaultCredentialsProvider()69   ~DefaultCredentialsProvider() override {}
70 
AddSecureType(const std::string & type,std::unique_ptr<CredentialTypeProvider> type_provider)71   void AddSecureType(
72       const std::string& type,
73       std::unique_ptr<CredentialTypeProvider> type_provider) override {
74     // This clobbers any existing entry for type, except the defaults, which
75     // can't be clobbered.
76     std::unique_lock<std::mutex> lock(mu_);
77     auto it = std::find(added_secure_type_names_.begin(),
78                         added_secure_type_names_.end(), type);
79     if (it == added_secure_type_names_.end()) {
80       added_secure_type_names_.push_back(type);
81       added_secure_type_providers_.push_back(std::move(type_provider));
82     } else {
83       added_secure_type_providers_[it - added_secure_type_names_.begin()] =
84           std::move(type_provider);
85     }
86   }
87 
GetChannelCredentials(const std::string & type,ChannelArguments * args)88   std::shared_ptr<ChannelCredentials> GetChannelCredentials(
89       const std::string& type, ChannelArguments* args) override {
90     if (type == grpc::testing::kInsecureCredentialsType) {
91       return InsecureChannelCredentials();
92     } else if (type == grpc::testing::kAltsCredentialsType) {
93       grpc::experimental::AltsCredentialsOptions alts_opts;
94       return grpc::experimental::AltsCredentials(alts_opts);
95     } else if (type == grpc::testing::kTlsCredentialsType) {
96       SslCredentialsOptions ssl_opts = {test_root_cert, "", ""};
97       args->SetSslTargetNameOverride("foo.test.google.fr");
98       return grpc::SslCredentials(ssl_opts);
99     } else if (type == grpc::testing::kGoogleDefaultCredentialsType) {
100       return grpc::GoogleDefaultCredentials();
101     } else {
102       std::unique_lock<std::mutex> lock(mu_);
103       auto it(std::find(added_secure_type_names_.begin(),
104                         added_secure_type_names_.end(), type));
105       if (it == added_secure_type_names_.end()) {
106         gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
107         return nullptr;
108       }
109       return added_secure_type_providers_[it - added_secure_type_names_.begin()]
110           ->GetChannelCredentials(args);
111     }
112   }
113 
GetServerCredentials(const std::string & type)114   std::shared_ptr<ServerCredentials> GetServerCredentials(
115       const std::string& type) override {
116     if (type == grpc::testing::kInsecureCredentialsType) {
117       return InsecureServerCredentials();
118     } else if (type == grpc::testing::kAltsCredentialsType) {
119       grpc::experimental::AltsServerCredentialsOptions alts_opts;
120       return grpc::experimental::AltsServerCredentials(alts_opts);
121     } else if (type == grpc::testing::kTlsCredentialsType) {
122       SslServerCredentialsOptions ssl_opts;
123       ssl_opts.pem_root_certs = "";
124       if (!custom_server_key_.empty() && !custom_server_cert_.empty()) {
125         SslServerCredentialsOptions::PemKeyCertPair pkcp = {
126             custom_server_key_, custom_server_cert_};
127         ssl_opts.pem_key_cert_pairs.push_back(pkcp);
128       } else {
129         SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key,
130                                                             test_server1_cert};
131         ssl_opts.pem_key_cert_pairs.push_back(pkcp);
132       }
133       return SslServerCredentials(ssl_opts);
134     } else {
135       std::unique_lock<std::mutex> lock(mu_);
136       auto it(std::find(added_secure_type_names_.begin(),
137                         added_secure_type_names_.end(), type));
138       if (it == added_secure_type_names_.end()) {
139         gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
140         return nullptr;
141       }
142       return added_secure_type_providers_[it - added_secure_type_names_.begin()]
143           ->GetServerCredentials();
144     }
145   }
GetSecureCredentialsTypeList()146   std::vector<std::string> GetSecureCredentialsTypeList() override {
147     std::vector<std::string> types;
148     types.push_back(grpc::testing::kTlsCredentialsType);
149     std::unique_lock<std::mutex> lock(mu_);
150     for (auto it = added_secure_type_names_.begin();
151          it != added_secure_type_names_.end(); it++) {
152       types.push_back(*it);
153     }
154     return types;
155   }
156 
157  private:
158   std::mutex mu_;
159   std::vector<std::string> added_secure_type_names_;
160   std::vector<std::unique_ptr<CredentialTypeProvider>>
161       added_secure_type_providers_;
162   std::string custom_server_key_;
163   std::string custom_server_cert_;
164 };
165 
166 CredentialsProvider* g_provider = nullptr;
167 
168 }  // namespace
169 
GetCredentialsProvider()170 CredentialsProvider* GetCredentialsProvider() {
171   if (g_provider == nullptr) {
172     g_provider = new DefaultCredentialsProvider;
173   }
174   return g_provider;
175 }
176 
// Installs `provider` as the process-wide CredentialsProvider that
// GetCredentialsProvider() will return.
// It must be called before any provider has been installed or lazily created
// by GetCredentialsProvider(); otherwise the GPR_ASSERT below fails.
// The pointer is stored as-is, and nothing in this file deletes it.
void SetCredentialsProvider(CredentialsProvider* provider) {
  // For now, forbids overriding provider.
  GPR_ASSERT(g_provider == nullptr);
  g_provider = provider;
}
181 }
182 
183 }  // namespace testing
184 }  // namespace grpc
185