1
2 //
3 //
4 // Copyright 2016 gRPC authors.
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 //
10 // http://www.apache.org/licenses/LICENSE-2.0
11 //
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 //
18 //
19
#include "test/cpp/util/test_credentials_provider.h"

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <mutex>
#include <unordered_map>

#include "absl/flags/flag.h"

#include <grpc/support/log.h>
#include <grpc/support/sync.h>
#include <grpcpp/security/server_credentials.h>

#include "src/core/lib/gprpp/crash.h"
#include "test/core/end2end/data/ssl_test_data.h"
36
// Optional server-side key/cert files for the TLS credentials type
// (presumably PEM-encoded, given they populate a PemKeyCertPair -- confirm).
// Both are read once, when DefaultCredentialsProvider is constructed; the
// custom pair is used only if BOTH files yield non-empty contents, otherwise
// the bundled test_server1 key/cert pair is used instead.
ABSL_FLAG(std::string, tls_cert_file, "",
          "The TLS cert file used when --use_tls=true");
ABSL_FLAG(std::string, tls_key_file, "",
          "The TLS key file used when --use_tls=true");
41
42 namespace grpc {
43 namespace testing {
44 namespace {
45
// Reads the entire contents of the file at `src_path` in binary mode and
// returns them as a string (an empty file yields an empty string).
//
// If the file cannot be opened, this prints the offending path to stderr and
// aborts. Previously an unopenable file made tellg() return -1, which was
// converted to SIZE_MAX and passed to reserve(), failing with an opaque
// std::length_error that never mentioned the path.
std::string ReadFile(const std::string& src_path) {
  std::ifstream src(src_path, std::ifstream::in | std::ifstream::binary);
  if (!src.is_open()) {
    std::cerr << "Failed to open file: " << src_path << '\n';
    std::abort();
  }

  std::string contents;
  // Pre-size the buffer only when the stream reports a usable length;
  // tellg() returns -1 for streams that cannot seek.
  src.seekg(0, std::ios::end);
  const std::streamoff size = src.tellg();
  if (size > 0) {
    contents.reserve(static_cast<std::string::size_type>(size));
  }
  // A failed seek leaves failbit set, which would make the rewind a no-op.
  src.clear();
  src.seekg(0, std::ios::beg);
  contents.assign(std::istreambuf_iterator<char>(src),
                  std::istreambuf_iterator<char>());
  return contents;
}
58
59 class DefaultCredentialsProvider : public CredentialsProvider {
60 public:
DefaultCredentialsProvider()61 DefaultCredentialsProvider() {
62 if (!absl::GetFlag(FLAGS_tls_key_file).empty()) {
63 custom_server_key_ = ReadFile(absl::GetFlag(FLAGS_tls_key_file));
64 }
65 if (!absl::GetFlag(FLAGS_tls_cert_file).empty()) {
66 custom_server_cert_ = ReadFile(absl::GetFlag(FLAGS_tls_cert_file));
67 }
68 }
~DefaultCredentialsProvider()69 ~DefaultCredentialsProvider() override {}
70
AddSecureType(const std::string & type,std::unique_ptr<CredentialTypeProvider> type_provider)71 void AddSecureType(
72 const std::string& type,
73 std::unique_ptr<CredentialTypeProvider> type_provider) override {
74 // This clobbers any existing entry for type, except the defaults, which
75 // can't be clobbered.
76 std::unique_lock<std::mutex> lock(mu_);
77 auto it = std::find(added_secure_type_names_.begin(),
78 added_secure_type_names_.end(), type);
79 if (it == added_secure_type_names_.end()) {
80 added_secure_type_names_.push_back(type);
81 added_secure_type_providers_.push_back(std::move(type_provider));
82 } else {
83 added_secure_type_providers_[it - added_secure_type_names_.begin()] =
84 std::move(type_provider);
85 }
86 }
87
GetChannelCredentials(const std::string & type,ChannelArguments * args)88 std::shared_ptr<ChannelCredentials> GetChannelCredentials(
89 const std::string& type, ChannelArguments* args) override {
90 if (type == grpc::testing::kInsecureCredentialsType) {
91 return InsecureChannelCredentials();
92 } else if (type == grpc::testing::kAltsCredentialsType) {
93 grpc::experimental::AltsCredentialsOptions alts_opts;
94 return grpc::experimental::AltsCredentials(alts_opts);
95 } else if (type == grpc::testing::kTlsCredentialsType) {
96 SslCredentialsOptions ssl_opts = {test_root_cert, "", ""};
97 args->SetSslTargetNameOverride("foo.test.google.fr");
98 return grpc::SslCredentials(ssl_opts);
99 } else if (type == grpc::testing::kGoogleDefaultCredentialsType) {
100 return grpc::GoogleDefaultCredentials();
101 } else {
102 std::unique_lock<std::mutex> lock(mu_);
103 auto it(std::find(added_secure_type_names_.begin(),
104 added_secure_type_names_.end(), type));
105 if (it == added_secure_type_names_.end()) {
106 gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
107 return nullptr;
108 }
109 return added_secure_type_providers_[it - added_secure_type_names_.begin()]
110 ->GetChannelCredentials(args);
111 }
112 }
113
GetServerCredentials(const std::string & type)114 std::shared_ptr<ServerCredentials> GetServerCredentials(
115 const std::string& type) override {
116 if (type == grpc::testing::kInsecureCredentialsType) {
117 return InsecureServerCredentials();
118 } else if (type == grpc::testing::kAltsCredentialsType) {
119 grpc::experimental::AltsServerCredentialsOptions alts_opts;
120 return grpc::experimental::AltsServerCredentials(alts_opts);
121 } else if (type == grpc::testing::kTlsCredentialsType) {
122 SslServerCredentialsOptions ssl_opts;
123 ssl_opts.pem_root_certs = "";
124 if (!custom_server_key_.empty() && !custom_server_cert_.empty()) {
125 SslServerCredentialsOptions::PemKeyCertPair pkcp = {
126 custom_server_key_, custom_server_cert_};
127 ssl_opts.pem_key_cert_pairs.push_back(pkcp);
128 } else {
129 SslServerCredentialsOptions::PemKeyCertPair pkcp = {test_server1_key,
130 test_server1_cert};
131 ssl_opts.pem_key_cert_pairs.push_back(pkcp);
132 }
133 return SslServerCredentials(ssl_opts);
134 } else {
135 std::unique_lock<std::mutex> lock(mu_);
136 auto it(std::find(added_secure_type_names_.begin(),
137 added_secure_type_names_.end(), type));
138 if (it == added_secure_type_names_.end()) {
139 gpr_log(GPR_ERROR, "Unsupported credentials type %s.", type.c_str());
140 return nullptr;
141 }
142 return added_secure_type_providers_[it - added_secure_type_names_.begin()]
143 ->GetServerCredentials();
144 }
145 }
GetSecureCredentialsTypeList()146 std::vector<std::string> GetSecureCredentialsTypeList() override {
147 std::vector<std::string> types;
148 types.push_back(grpc::testing::kTlsCredentialsType);
149 std::unique_lock<std::mutex> lock(mu_);
150 for (auto it = added_secure_type_names_.begin();
151 it != added_secure_type_names_.end(); it++) {
152 types.push_back(*it);
153 }
154 return types;
155 }
156
157 private:
158 std::mutex mu_;
159 std::vector<std::string> added_secure_type_names_;
160 std::vector<std::unique_ptr<CredentialTypeProvider>>
161 added_secure_type_providers_;
162 std::string custom_server_key_;
163 std::string custom_server_cert_;
164 };
165
166 CredentialsProvider* g_provider = nullptr;
167
168 } // namespace
169
GetCredentialsProvider()170 CredentialsProvider* GetCredentialsProvider() {
171 if (g_provider == nullptr) {
172 g_provider = new DefaultCredentialsProvider;
173 }
174 return g_provider;
175 }
176
SetCredentialsProvider(CredentialsProvider * provider)177 void SetCredentialsProvider(CredentialsProvider* provider) {
178 // For now, forbids overriding provider.
179 GPR_ASSERT(g_provider == nullptr);
180 g_provider = provider;
181 }
182
183 } // namespace testing
184 } // namespace grpc
185