1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_test_util.h"
6
7 #include <stddef.h>
8
9 #include <algorithm>
10 #include <sstream>
11 #include <utility>
12
13 #include "base/check.h"
14 #include "base/containers/span.h"
15 #include "base/strings/strcat.h"
16 #include "base/strings/string_util.h"
17 #include "base/strings/stringprintf.h"
18 #include "net/base/net_errors.h"
19 #include "net/http/http_network_session.h"
20 #include "net/proxy_resolution/configured_proxy_resolution_service.h"
21 #include "net/proxy_resolution/proxy_resolution_service.h"
22 #include "net/socket/socket_test_util.h"
23 #include "net/third_party/quiche/src/quiche/spdy/core/http2_header_block.h"
24 #include "net/third_party/quiche/src/quiche/spdy/core/spdy_protocol.h"
25 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
26 #include "net/url_request/url_request_context.h"
27 #include "net/url_request/url_request_context_builder.h"
28 #include "net/websockets/websocket_basic_handshake_stream.h"
29 #include "url/origin.h"
30
31 namespace net {
32 class AuthChallengeInfo;
33 class AuthCredentials;
34 class HttpResponseHeaders;
35 class WebSocketHttp2HandshakeStream;
36 class WebSocketHttp3HandshakeStream;
37
namespace {

// Parameters of the linear congruential generator implemented by
// LinearCongruentialGenerator::Generate():
//   state' = (kA * state + kC) mod kM, with kM = 2^48.
// kA is 0x5851f42d4c957f2d, spelled here as a single 64-bit literal.
constexpr uint64_t kA = uint64_t{0x5851f42d4c957f2d};
constexpr uint64_t kC = 12345;
constexpr uint64_t kM = uint64_t{1} << 48;

}  // namespace
46
LinearCongruentialGenerator(uint32_t seed)47 LinearCongruentialGenerator::LinearCongruentialGenerator(uint32_t seed)
48 : current_(seed) {}
49
Generate()50 uint32_t LinearCongruentialGenerator::Generate() {
51 uint64_t result = current_;
52 current_ = (current_ * kA + kC) % kM;
53 return static_cast<uint32_t>(result >> 16);
54 }
55
WebSocketExtraHeadersToString(const WebSocketExtraHeaders & headers)56 std::string WebSocketExtraHeadersToString(
57 const WebSocketExtraHeaders& headers) {
58 std::string answer;
59 for (const auto& header : headers) {
60 base::StrAppend(&answer, {header.first, ": ", header.second, "\r\n"});
61 }
62 return answer;
63 }
64
WebSocketExtraHeadersToHttpRequestHeaders(const WebSocketExtraHeaders & headers)65 HttpRequestHeaders WebSocketExtraHeadersToHttpRequestHeaders(
66 const WebSocketExtraHeaders& headers) {
67 HttpRequestHeaders headers_to_return;
68 for (const auto& header : headers)
69 headers_to_return.SetHeader(header.first, header.second);
70 return headers_to_return;
71 }
72
WebSocketStandardRequest(const std::string & path,const std::string & host,const url::Origin & origin,const WebSocketExtraHeaders & send_additional_request_headers,const WebSocketExtraHeaders & extra_headers)73 std::string WebSocketStandardRequest(
74 const std::string& path,
75 const std::string& host,
76 const url::Origin& origin,
77 const WebSocketExtraHeaders& send_additional_request_headers,
78 const WebSocketExtraHeaders& extra_headers) {
79 return WebSocketStandardRequestWithCookies(path, host, origin, /*cookies=*/{},
80 send_additional_request_headers,
81 extra_headers);
82 }
83
WebSocketStandardRequestWithCookies(const std::string & path,const std::string & host,const url::Origin & origin,const WebSocketExtraHeaders & cookies,const WebSocketExtraHeaders & send_additional_request_headers,const WebSocketExtraHeaders & extra_headers)84 std::string WebSocketStandardRequestWithCookies(
85 const std::string& path,
86 const std::string& host,
87 const url::Origin& origin,
88 const WebSocketExtraHeaders& cookies,
89 const WebSocketExtraHeaders& send_additional_request_headers,
90 const WebSocketExtraHeaders& extra_headers) {
91 // Unrelated changes in net/http may change the order and default-values of
92 // HTTP headers, causing WebSocket tests to fail. It is safe to update this
93 // in that case.
94 HttpRequestHeaders headers;
95 std::stringstream request_headers;
96
97 request_headers << base::StringPrintf("GET %s HTTP/1.1\r\n", path.c_str());
98 headers.SetHeader("Host", host);
99 headers.SetHeader("Connection", "Upgrade");
100 headers.SetHeader("Pragma", "no-cache");
101 headers.SetHeader("Cache-Control", "no-cache");
102 for (const auto& [key, value] : send_additional_request_headers)
103 headers.SetHeader(key, value);
104 headers.SetHeader("Upgrade", "websocket");
105 headers.SetHeader("Origin", origin.Serialize());
106 headers.SetHeader("Sec-WebSocket-Version", "13");
107 if (!headers.HasHeader("User-Agent"))
108 headers.SetHeader("User-Agent", "");
109 headers.SetHeader("Accept-Encoding", "gzip, deflate");
110 headers.SetHeader("Accept-Language", "en-us,fr");
111 for (const auto& [key, value] : cookies)
112 headers.SetHeader(key, value);
113 headers.SetHeader("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==");
114 headers.SetHeader("Sec-WebSocket-Extensions",
115 "permessage-deflate; client_max_window_bits");
116 for (const auto& [key, value] : extra_headers)
117 headers.SetHeader(key, value);
118
119 request_headers << headers.ToString();
120 return request_headers.str();
121 }
122
WebSocketStandardResponse(const std::string & extra_headers)123 std::string WebSocketStandardResponse(const std::string& extra_headers) {
124 return base::StrCat(
125 {"HTTP/1.1 101 Switching Protocols\r\n"
126 "Upgrade: websocket\r\n"
127 "Connection: Upgrade\r\n"
128 "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n",
129 extra_headers, "\r\n"});
130 }
131
WebSocketCommonTestHeaders()132 HttpRequestHeaders WebSocketCommonTestHeaders() {
133 HttpRequestHeaders request_headers;
134 request_headers.SetHeader("Host", "www.example.org");
135 request_headers.SetHeader("Connection", "Upgrade");
136 request_headers.SetHeader("Pragma", "no-cache");
137 request_headers.SetHeader("Cache-Control", "no-cache");
138 request_headers.SetHeader("Upgrade", "websocket");
139 request_headers.SetHeader("Origin", "http://origin.example.org");
140 request_headers.SetHeader("Sec-WebSocket-Version", "13");
141 request_headers.SetHeader("User-Agent", "");
142 request_headers.SetHeader("Accept-Encoding", "gzip, deflate");
143 request_headers.SetHeader("Accept-Language", "en-us,fr");
144 return request_headers;
145 }
146
WebSocketHttp2Request(const std::string & path,const std::string & authority,const std::string & origin,const WebSocketExtraHeaders & extra_headers)147 spdy::Http2HeaderBlock WebSocketHttp2Request(
148 const std::string& path,
149 const std::string& authority,
150 const std::string& origin,
151 const WebSocketExtraHeaders& extra_headers) {
152 spdy::Http2HeaderBlock request_headers;
153 request_headers[spdy::kHttp2MethodHeader] = "CONNECT";
154 request_headers[spdy::kHttp2AuthorityHeader] = authority;
155 request_headers[spdy::kHttp2SchemeHeader] = "https";
156 request_headers[spdy::kHttp2PathHeader] = path;
157 request_headers[spdy::kHttp2ProtocolHeader] = "websocket";
158 request_headers["pragma"] = "no-cache";
159 request_headers["cache-control"] = "no-cache";
160 request_headers["origin"] = origin;
161 request_headers["sec-websocket-version"] = "13";
162 request_headers["user-agent"] = "";
163 request_headers["accept-encoding"] = "gzip, deflate";
164 request_headers["accept-language"] = "en-us,fr";
165 request_headers["sec-websocket-extensions"] =
166 "permessage-deflate; client_max_window_bits";
167 for (const auto& header : extra_headers) {
168 request_headers[base::ToLowerASCII(header.first)] = header.second;
169 }
170 return request_headers;
171 }
172
WebSocketHttp2Response(const WebSocketExtraHeaders & extra_headers)173 spdy::Http2HeaderBlock WebSocketHttp2Response(
174 const WebSocketExtraHeaders& extra_headers) {
175 spdy::Http2HeaderBlock response_headers;
176 response_headers[spdy::kHttp2StatusHeader] = "200";
177 for (const auto& header : extra_headers) {
178 response_headers[base::ToLowerASCII(header.first)] = header.second;
179 }
180 return response_headers;
181 }
182
// Storage behind WebSocketMockClientSocketFactoryMaker. The mock reads/write
// and the socket data providers handed to |factory| refer to these members by
// pointer, so members must not be reordered casually: they are destroyed in
// reverse declaration order.
struct WebSocketMockClientSocketFactoryMaker::Detail {
  // Copies of the strings given to SetExpectations(). |write| and |reads|
  // point into these buffers, so they must stay alive and unmodified.
  std::string expect_written;
  std::string return_to_read;
  // |return_to_read| split into synchronous reads of at most 4096 bytes each.
  std::vector<MockRead> reads;
  // The single synchronous write expecting |expect_written|.
  MockWrite write;
  // Own the providers that |factory| only receives as raw pointers.
  std::vector<std::unique_ptr<SequencedSocketData>> socket_data_vector;
  std::vector<std::unique_ptr<SSLSocketDataProvider>> ssl_socket_data_vector;
  // Declared last, so it is destroyed first, before the providers and buffers
  // it points at.
  MockClientSocketFactory factory;
};
192
WebSocketMockClientSocketFactoryMaker()193 WebSocketMockClientSocketFactoryMaker::WebSocketMockClientSocketFactoryMaker()
194 : detail_(std::make_unique<Detail>()) {}
195
196 WebSocketMockClientSocketFactoryMaker::
197 ~WebSocketMockClientSocketFactoryMaker() = default;
198
factory()199 MockClientSocketFactory* WebSocketMockClientSocketFactoryMaker::factory() {
200 return &detail_->factory;
201 }
202
SetExpectations(const std::string & expect_written,const std::string & return_to_read)203 void WebSocketMockClientSocketFactoryMaker::SetExpectations(
204 const std::string& expect_written,
205 const std::string& return_to_read) {
206 constexpr size_t kHttpStreamParserBufferSize = 4096;
207 // We need to extend the lifetime of these strings.
208 detail_->expect_written = expect_written;
209 detail_->return_to_read = return_to_read;
210 int sequence = 0;
211 detail_->write = MockWrite(SYNCHRONOUS,
212 detail_->expect_written.data(),
213 detail_->expect_written.size(),
214 sequence++);
215 // HttpStreamParser reads 4KB at a time. We need to take this implementation
216 // detail into account if |return_to_read| is big enough.
217 for (size_t place = 0; place < detail_->return_to_read.size();
218 place += kHttpStreamParserBufferSize) {
219 detail_->reads.emplace_back(SYNCHRONOUS,
220 detail_->return_to_read.data() + place,
221 std::min(detail_->return_to_read.size() - place,
222 kHttpStreamParserBufferSize),
223 sequence++);
224 }
225 auto socket_data = std::make_unique<SequencedSocketData>(
226 detail_->reads, base::make_span(&detail_->write, 1u));
227 socket_data->set_connect_data(MockConnect(SYNCHRONOUS, OK));
228 AddRawExpectations(std::move(socket_data));
229 }
230
AddRawExpectations(std::unique_ptr<SequencedSocketData> socket_data)231 void WebSocketMockClientSocketFactoryMaker::AddRawExpectations(
232 std::unique_ptr<SequencedSocketData> socket_data) {
233 detail_->factory.AddSocketDataProvider(socket_data.get());
234 detail_->socket_data_vector.push_back(std::move(socket_data));
235 }
236
AddSSLSocketDataProvider(std::unique_ptr<SSLSocketDataProvider> ssl_socket_data)237 void WebSocketMockClientSocketFactoryMaker::AddSSLSocketDataProvider(
238 std::unique_ptr<SSLSocketDataProvider> ssl_socket_data) {
239 detail_->factory.AddSSLSocketDataProvider(ssl_socket_data.get());
240 detail_->ssl_socket_data_vector.push_back(std::move(ssl_socket_data));
241 }
242
WebSocketTestURLRequestContextHost()243 WebSocketTestURLRequestContextHost::WebSocketTestURLRequestContextHost()
244 : url_request_context_builder_(CreateTestURLRequestContextBuilder()) {
245 url_request_context_builder_->set_client_socket_factory_for_testing(
246 maker_.factory());
247 HttpNetworkSessionParams params;
248 params.enable_spdy_ping_based_connection_checking = false;
249 params.enable_quic = false;
250 params.disable_idle_sockets_close_on_memory_pressure = false;
251 url_request_context_builder_->set_http_network_session_params(params);
252 }
253
254 WebSocketTestURLRequestContextHost::~WebSocketTestURLRequestContextHost() =
255 default;
256
AddRawExpectations(std::unique_ptr<SequencedSocketData> socket_data)257 void WebSocketTestURLRequestContextHost::AddRawExpectations(
258 std::unique_ptr<SequencedSocketData> socket_data) {
259 maker_.AddRawExpectations(std::move(socket_data));
260 }
261
AddSSLSocketDataProvider(std::unique_ptr<SSLSocketDataProvider> ssl_socket_data)262 void WebSocketTestURLRequestContextHost::AddSSLSocketDataProvider(
263 std::unique_ptr<SSLSocketDataProvider> ssl_socket_data) {
264 maker_.AddSSLSocketDataProvider(std::move(ssl_socket_data));
265 }
266
SetProxyConfig(const std::string & proxy_rules)267 void WebSocketTestURLRequestContextHost::SetProxyConfig(
268 const std::string& proxy_rules) {
269 DCHECK(!url_request_context_);
270 auto proxy_resolution_service =
271 ConfiguredProxyResolutionService::CreateFixedForTest(
272 proxy_rules, TRAFFIC_ANNOTATION_FOR_TESTS);
273 url_request_context_builder_->set_proxy_resolution_service(
274 std::move(proxy_resolution_service));
275 }
276
OnURLRequestConnected(URLRequest * request,const TransportInfo & info)277 void DummyConnectDelegate::OnURLRequestConnected(URLRequest* request,
278 const TransportInfo& info) {}
279
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & host_port_pair,base::OnceCallback<void (const AuthCredentials *)> callback,std::optional<AuthCredentials> * credentials)280 int DummyConnectDelegate::OnAuthRequired(
281 const AuthChallengeInfo& auth_info,
282 scoped_refptr<HttpResponseHeaders> response_headers,
283 const IPEndPoint& host_port_pair,
284 base::OnceCallback<void(const AuthCredentials*)> callback,
285 std::optional<AuthCredentials>* credentials) {
286 return OK;
287 }
288
GetURLRequestContext()289 URLRequestContext* WebSocketTestURLRequestContextHost::GetURLRequestContext() {
290 if (!url_request_context_) {
291 url_request_context_builder_->set_network_delegate(
292 std::make_unique<TestNetworkDelegate>());
293 url_request_context_ = url_request_context_builder_->Build();
294 url_request_context_builder_ = nullptr;
295 }
296 return url_request_context_.get();
297 }
298
OnBasicHandshakeStreamCreated(WebSocketBasicHandshakeStream * handshake_stream)299 void TestWebSocketStreamRequestAPI::OnBasicHandshakeStreamCreated(
300 WebSocketBasicHandshakeStream* handshake_stream) {
301 handshake_stream->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
302 }
303
OnHttp2HandshakeStreamCreated(WebSocketHttp2HandshakeStream * handshake_stream)304 void TestWebSocketStreamRequestAPI::OnHttp2HandshakeStreamCreated(
305 WebSocketHttp2HandshakeStream* handshake_stream) {}
306
OnHttp3HandshakeStreamCreated(WebSocketHttp3HandshakeStream * handshake_stream)307 void TestWebSocketStreamRequestAPI::OnHttp3HandshakeStreamCreated(
308 WebSocketHttp3HandshakeStream* handshake_stream) {}
309 } // namespace net
310