1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/cert/mock_cert_verifier.h"
6
7 #include <memory>
8 #include <utility>
9
10 #include "base/callback_list.h"
11 #include "base/functional/bind.h"
12 #include "base/location.h"
13 #include "base/memory/raw_ptr.h"
14 #include "base/memory/ref_counted.h"
15 #include "base/memory/weak_ptr.h"
16 #include "base/strings/pattern.h"
17 #include "base/strings/string_util.h"
18 #include "base/task/single_thread_task_runner.h"
19 #include "net/base/net_errors.h"
20 #include "net/cert/cert_status_flags.h"
21 #include "net/cert/cert_verify_result.h"
22 #include "net/cert/x509_certificate.h"
23
24 namespace net {
25
26 namespace {
27 // Helper function for setting the appropriate CertStatus given a net::Error.
MapNetErrorToCertStatus(int error)28 CertStatus MapNetErrorToCertStatus(int error) {
29 switch (error) {
30 case ERR_CERT_COMMON_NAME_INVALID:
31 return CERT_STATUS_COMMON_NAME_INVALID;
32 case ERR_CERT_DATE_INVALID:
33 return CERT_STATUS_DATE_INVALID;
34 case ERR_CERT_AUTHORITY_INVALID:
35 return CERT_STATUS_AUTHORITY_INVALID;
36 case ERR_CERT_NO_REVOCATION_MECHANISM:
37 return CERT_STATUS_NO_REVOCATION_MECHANISM;
38 case ERR_CERT_UNABLE_TO_CHECK_REVOCATION:
39 return CERT_STATUS_UNABLE_TO_CHECK_REVOCATION;
40 case ERR_CERTIFICATE_TRANSPARENCY_REQUIRED:
41 return CERT_STATUS_CERTIFICATE_TRANSPARENCY_REQUIRED;
42 case ERR_CERT_REVOKED:
43 return CERT_STATUS_REVOKED;
44 case ERR_CERT_INVALID:
45 return CERT_STATUS_INVALID;
46 case ERR_CERT_WEAK_SIGNATURE_ALGORITHM:
47 return CERT_STATUS_WEAK_SIGNATURE_ALGORITHM;
48 case ERR_CERT_NON_UNIQUE_NAME:
49 return CERT_STATUS_NON_UNIQUE_NAME;
50 case ERR_CERT_WEAK_KEY:
51 return CERT_STATUS_WEAK_KEY;
52 case ERR_SSL_PINNED_KEY_NOT_IN_CERT_CHAIN:
53 return CERT_STATUS_PINNED_KEY_MISSING;
54 case ERR_CERT_NAME_CONSTRAINT_VIOLATION:
55 return CERT_STATUS_NAME_CONSTRAINT_VIOLATION;
56 case ERR_CERT_VALIDITY_TOO_LONG:
57 return CERT_STATUS_VALIDITY_TOO_LONG;
58 case ERR_CERT_SYMANTEC_LEGACY:
59 return CERT_STATUS_SYMANTEC_LEGACY;
60 case ERR_CERT_KNOWN_INTERCEPTION_BLOCKED:
61 return (CERT_STATUS_KNOWN_INTERCEPTION_BLOCKED | CERT_STATUS_REVOKED);
62 default:
63 return 0;
64 }
65 }
66 } // namespace
67
68 struct MockCertVerifier::Rule {
Rulenet::MockCertVerifier::Rule69 Rule(scoped_refptr<X509Certificate> cert_arg,
70 const std::string& hostname_arg,
71 const CertVerifyResult& result_arg,
72 int rv_arg)
73 : cert(std::move(cert_arg)),
74 hostname(hostname_arg),
75 result(result_arg),
76 rv(rv_arg) {
77 DCHECK(cert);
78 DCHECK(result.verified_cert);
79 }
80
81 scoped_refptr<X509Certificate> cert;
82 std::string hostname;
83 CertVerifyResult result;
84 int rv;
85 };
86
87 class MockCertVerifier::MockRequest : public CertVerifier::Request {
88 public:
MockRequest(MockCertVerifier * parent,CertVerifyResult * result,CompletionOnceCallback callback)89 MockRequest(MockCertVerifier* parent,
90 CertVerifyResult* result,
91 CompletionOnceCallback callback)
92 : result_(result), callback_(std::move(callback)) {
93 subscription_ = parent->request_list_.Add(
94 base::BindOnce(&MockRequest::Cleanup, weak_factory_.GetWeakPtr()));
95 }
96
ReturnResultLater(int rv,const CertVerifyResult & result)97 void ReturnResultLater(int rv, const CertVerifyResult& result) {
98 base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
99 FROM_HERE, base::BindOnce(&MockRequest::ReturnResult,
100 weak_factory_.GetWeakPtr(), rv, result));
101 }
102
103 private:
ReturnResult(int rv,const CertVerifyResult & result)104 void ReturnResult(int rv, const CertVerifyResult& result) {
105 // If the MockCertVerifier has been deleted, the callback will have been
106 // reset to null.
107 if (!callback_)
108 return;
109
110 *result_ = result;
111 std::move(callback_).Run(rv);
112 }
113
Cleanup()114 void Cleanup() {
115 // Note: May delete |this_|.
116 std::move(callback_).Reset();
117 }
118
119 raw_ptr<CertVerifyResult> result_;
120 CompletionOnceCallback callback_;
121 base::CallbackListSubscription subscription_;
122
123 base::WeakPtrFactory<MockRequest> weak_factory_{this};
124 };
125
126 MockCertVerifier::MockCertVerifier() = default;
127
~MockCertVerifier()128 MockCertVerifier::~MockCertVerifier() {
129 // Reset the callbacks for any outstanding MockRequests to fulfill the
130 // respective net::CertVerifier contract.
131 request_list_.Notify();
132 }
133
Verify(const RequestParams & params,CertVerifyResult * verify_result,CompletionOnceCallback callback,std::unique_ptr<Request> * out_req,const NetLogWithSource & net_log)134 int MockCertVerifier::Verify(const RequestParams& params,
135 CertVerifyResult* verify_result,
136 CompletionOnceCallback callback,
137 std::unique_ptr<Request>* out_req,
138 const NetLogWithSource& net_log) {
139 if (!async_) {
140 return VerifyImpl(params, verify_result);
141 }
142
143 auto request =
144 std::make_unique<MockRequest>(this, verify_result, std::move(callback));
145 CertVerifyResult result;
146 int rv = VerifyImpl(params, &result);
147 request->ReturnResultLater(rv, result);
148 *out_req = std::move(request);
149 return ERR_IO_PENDING;
150 }
151
AddObserver(Observer * observer)152 void MockCertVerifier::AddObserver(Observer* observer) {
153 observers_.AddObserver(observer);
154 }
155
RemoveObserver(Observer * observer)156 void MockCertVerifier::RemoveObserver(Observer* observer) {
157 observers_.RemoveObserver(observer);
158 }
159
AddResultForCert(scoped_refptr<X509Certificate> cert,const CertVerifyResult & verify_result,int rv)160 void MockCertVerifier::AddResultForCert(scoped_refptr<X509Certificate> cert,
161 const CertVerifyResult& verify_result,
162 int rv) {
163 AddResultForCertAndHost(std::move(cert), "*", verify_result, rv);
164 }
165
AddResultForCertAndHost(scoped_refptr<X509Certificate> cert,const std::string & host_pattern,const CertVerifyResult & verify_result,int rv)166 void MockCertVerifier::AddResultForCertAndHost(
167 scoped_refptr<X509Certificate> cert,
168 const std::string& host_pattern,
169 const CertVerifyResult& verify_result,
170 int rv) {
171 rules_.push_back(Rule(std::move(cert), host_pattern, verify_result, rv));
172 }
173
ClearRules()174 void MockCertVerifier::ClearRules() {
175 rules_.clear();
176 }
177
SimulateOnCertVerifierChanged()178 void MockCertVerifier::SimulateOnCertVerifierChanged() {
179 for (Observer& observer : observers_) {
180 observer.OnCertVerifierChanged();
181 }
182 }
183
VerifyImpl(const RequestParams & params,CertVerifyResult * verify_result)184 int MockCertVerifier::VerifyImpl(const RequestParams& params,
185 CertVerifyResult* verify_result) {
186 for (const Rule& rule : rules_) {
187 // Check just the server cert. Intermediates will be ignored.
188 if (!rule.cert->EqualsExcludingChain(params.certificate().get()))
189 continue;
190 if (!base::MatchPattern(params.hostname(), rule.hostname))
191 continue;
192 *verify_result = rule.result;
193 return rule.rv;
194 }
195
196 // Fall through to the default.
197 verify_result->verified_cert = params.certificate();
198 verify_result->cert_status = MapNetErrorToCertStatus(default_result_);
199 return default_result_;
200 }
201
202 ParamRecordingMockCertVerifier::ParamRecordingMockCertVerifier() = default;
203 ParamRecordingMockCertVerifier::~ParamRecordingMockCertVerifier() = default;
204
Verify(const RequestParams & params,CertVerifyResult * verify_result,CompletionOnceCallback callback,std::unique_ptr<Request> * out_req,const NetLogWithSource & net_log)205 int ParamRecordingMockCertVerifier::Verify(const RequestParams& params,
206 CertVerifyResult* verify_result,
207 CompletionOnceCallback callback,
208 std::unique_ptr<Request>* out_req,
209 const NetLogWithSource& net_log) {
210 params_.push_back(params);
211 return MockCertVerifier::Verify(params, verify_result, std::move(callback),
212 out_req, net_log);
213 }
214
CertVerifierObserverCounter(CertVerifier * verifier)215 CertVerifierObserverCounter::CertVerifierObserverCounter(
216 CertVerifier* verifier) {
217 obs_.Observe(verifier);
218 }
219
220 CertVerifierObserverCounter::~CertVerifierObserverCounter() = default;
221
OnCertVerifierChanged()222 void CertVerifierObserverCounter::OnCertVerifierChanged() {
223 change_count_++;
224 }
225
226 } // namespace net
227