xref: /aosp_15_r20/external/cronet/net/websockets/websocket_http2_handshake_stream.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_http2_handshake_stream.h"
6 
7 #include <set>
8 #include <string_view>
9 #include <utility>
10 
11 #include "base/check.h"
12 #include "base/check_op.h"
13 #include "base/functional/bind.h"
14 #include "base/functional/callback.h"
15 #include "base/memory/scoped_refptr.h"
16 #include "base/notreached.h"
17 #include "base/strings/strcat.h"
18 #include "base/strings/stringprintf.h"
19 #include "base/time/time.h"
20 #include "net/base/ip_endpoint.h"
21 #include "net/http/http_request_headers.h"
22 #include "net/http/http_request_info.h"
23 #include "net/http/http_response_headers.h"
24 #include "net/http/http_response_info.h"
25 #include "net/http/http_status_code.h"
26 #include "net/spdy/spdy_http_utils.h"
27 #include "net/spdy/spdy_session.h"
28 #include "net/spdy/spdy_stream.h"
29 #include "net/third_party/quiche/src/quiche/spdy/core/http2_header_block.h"
30 #include "net/traffic_annotation/network_traffic_annotation.h"
31 #include "net/websockets/websocket_basic_stream.h"
32 #include "net/websockets/websocket_deflate_predictor_impl.h"
33 #include "net/websockets/websocket_deflate_stream.h"
34 #include "net/websockets/websocket_handshake_constants.h"
35 #include "net/websockets/websocket_handshake_request_info.h"
36 
37 namespace net {
38 
39 namespace {
40 
ValidateStatus(const HttpResponseHeaders * headers)41 bool ValidateStatus(const HttpResponseHeaders* headers) {
42   return headers->GetStatusLine() == "HTTP/1.1 200";
43 }
44 
45 }  // namespace
46 
WebSocketHttp2HandshakeStream(base::WeakPtr<SpdySession> session,WebSocketStream::ConnectDelegate * connect_delegate,std::vector<std::string> requested_sub_protocols,std::vector<std::string> requested_extensions,WebSocketStreamRequestAPI * request,std::set<std::string> dns_aliases)47 WebSocketHttp2HandshakeStream::WebSocketHttp2HandshakeStream(
48     base::WeakPtr<SpdySession> session,
49     WebSocketStream::ConnectDelegate* connect_delegate,
50     std::vector<std::string> requested_sub_protocols,
51     std::vector<std::string> requested_extensions,
52     WebSocketStreamRequestAPI* request,
53     std::set<std::string> dns_aliases)
54     : session_(session),
55       connect_delegate_(connect_delegate),
56       requested_sub_protocols_(requested_sub_protocols),
57       requested_extensions_(requested_extensions),
58       stream_request_(request),
59       dns_aliases_(std::move(dns_aliases)) {
60   DCHECK(connect_delegate);
61   DCHECK(request);
62 }
63 
~WebSocketHttp2HandshakeStream()64 WebSocketHttp2HandshakeStream::~WebSocketHttp2HandshakeStream() {
65   spdy_stream_request_.reset();
66   RecordHandshakeResult(result_);
67 }
68 
RegisterRequest(const HttpRequestInfo * request_info)69 void WebSocketHttp2HandshakeStream::RegisterRequest(
70     const HttpRequestInfo* request_info) {
71   DCHECK(request_info);
72   DCHECK(request_info->traffic_annotation.is_valid());
73   request_info_ = request_info;
74 }
75 
InitializeStream(bool can_send_early,RequestPriority priority,const NetLogWithSource & net_log,CompletionOnceCallback callback)76 int WebSocketHttp2HandshakeStream::InitializeStream(
77     bool can_send_early,
78     RequestPriority priority,
79     const NetLogWithSource& net_log,
80     CompletionOnceCallback callback) {
81   priority_ = priority;
82   net_log_ = net_log;
83   return OK;
84 }
85 
SendRequest(const HttpRequestHeaders & headers,HttpResponseInfo * response,CompletionOnceCallback callback)86 int WebSocketHttp2HandshakeStream::SendRequest(
87     const HttpRequestHeaders& headers,
88     HttpResponseInfo* response,
89     CompletionOnceCallback callback) {
90   DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
91   DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
92   DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
93   DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
94   DCHECK(headers.HasHeader(websockets::kUpgrade));
95   DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
96   DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
97 
98   if (!session_) {
99     const int rv = ERR_CONNECTION_CLOSED;
100     OnFailure("Connection closed before sending request.", rv, std::nullopt);
101     return rv;
102   }
103 
104   http_response_info_ = response;
105 
106   IPEndPoint address;
107   int result = session_->GetPeerAddress(&address);
108   if (result != OK) {
109     OnFailure("Error getting IP address.", result, std::nullopt);
110     return result;
111   }
112   http_response_info_->remote_endpoint = address;
113 
114   auto request = std::make_unique<WebSocketHandshakeRequestInfo>(
115       request_info_->url, base::Time::Now());
116   request->headers = headers;
117 
118   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
119                             requested_extensions_, &request->headers);
120   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
121                             requested_sub_protocols_, &request->headers);
122 
123   CreateSpdyHeadersFromHttpRequestForWebSocket(
124       request_info_->url, request->headers, &http2_request_headers_);
125 
126   connect_delegate_->OnStartOpeningHandshake(std::move(request));
127 
128   callback_ = std::move(callback);
129   spdy_stream_request_ = std::make_unique<SpdyStreamRequest>();
130   // The initial request for the WebSocket is a CONNECT, so there is no need to
131   // call ConfirmHandshake().
132   int rv = spdy_stream_request_->StartRequest(
133       SPDY_BIDIRECTIONAL_STREAM, session_, request_info_->url, true, priority_,
134       request_info_->socket_tag, net_log_,
135       base::BindOnce(&WebSocketHttp2HandshakeStream::StartRequestCallback,
136                      base::Unretained(this)),
137       NetworkTrafficAnnotationTag(request_info_->traffic_annotation));
138   if (rv == OK) {
139     StartRequestCallback(rv);
140     return ERR_IO_PENDING;
141   }
142   return rv;
143 }
144 
ReadResponseHeaders(CompletionOnceCallback callback)145 int WebSocketHttp2HandshakeStream::ReadResponseHeaders(
146     CompletionOnceCallback callback) {
147   if (stream_closed_)
148     return stream_error_;
149 
150   if (response_headers_complete_)
151     return ValidateResponse();
152 
153   callback_ = std::move(callback);
154   return ERR_IO_PENDING;
155 }
156 
ReadResponseBody(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)157 int WebSocketHttp2HandshakeStream::ReadResponseBody(
158     IOBuffer* buf,
159     int buf_len,
160     CompletionOnceCallback callback) {
161   // Callers should instead call Upgrade() to get a WebSocketStream
162   // and call ReadFrames() on that.
163   NOTREACHED();
164   return OK;
165 }
166 
Close(bool not_reusable)167 void WebSocketHttp2HandshakeStream::Close(bool not_reusable) {
168   spdy_stream_request_.reset();
169   if (stream_) {
170     stream_ = nullptr;
171     stream_closed_ = true;
172     stream_error_ = ERR_CONNECTION_CLOSED;
173   }
174   stream_adapter_.reset();
175 }
176 
IsResponseBodyComplete() const177 bool WebSocketHttp2HandshakeStream::IsResponseBodyComplete() const {
178   return false;
179 }
180 
IsConnectionReused() const181 bool WebSocketHttp2HandshakeStream::IsConnectionReused() const {
182   return true;
183 }
184 
SetConnectionReused()185 void WebSocketHttp2HandshakeStream::SetConnectionReused() {}
186 
CanReuseConnection() const187 bool WebSocketHttp2HandshakeStream::CanReuseConnection() const {
188   return false;
189 }
190 
GetTotalReceivedBytes() const191 int64_t WebSocketHttp2HandshakeStream::GetTotalReceivedBytes() const {
192   return stream_ ? stream_->raw_received_bytes() : 0;
193 }
194 
GetTotalSentBytes() const195 int64_t WebSocketHttp2HandshakeStream::GetTotalSentBytes() const {
196   return stream_ ? stream_->raw_sent_bytes() : 0;
197 }
198 
GetAlternativeService(AlternativeService * alternative_service) const199 bool WebSocketHttp2HandshakeStream::GetAlternativeService(
200     AlternativeService* alternative_service) const {
201   return false;
202 }
203 
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const204 bool WebSocketHttp2HandshakeStream::GetLoadTimingInfo(
205     LoadTimingInfo* load_timing_info) const {
206   return stream_ && stream_->GetLoadTimingInfo(load_timing_info);
207 }
208 
GetSSLInfo(SSLInfo * ssl_info)209 void WebSocketHttp2HandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
210   if (stream_)
211     stream_->GetSSLInfo(ssl_info);
212 }
213 
GetRemoteEndpoint(IPEndPoint * endpoint)214 int WebSocketHttp2HandshakeStream::GetRemoteEndpoint(IPEndPoint* endpoint) {
215   if (!session_)
216     return ERR_SOCKET_NOT_CONNECTED;
217 
218   return session_->GetRemoteEndpoint(endpoint);
219 }
220 
PopulateNetErrorDetails(NetErrorDetails *)221 void WebSocketHttp2HandshakeStream::PopulateNetErrorDetails(
222     NetErrorDetails* /*details*/) {
223   return;
224 }
225 
Drain(HttpNetworkSession * session)226 void WebSocketHttp2HandshakeStream::Drain(HttpNetworkSession* session) {
227   Close(true /* not_reusable */);
228 }
229 
SetPriority(RequestPriority priority)230 void WebSocketHttp2HandshakeStream::SetPriority(RequestPriority priority) {
231   priority_ = priority;
232   if (stream_)
233     stream_->SetPriority(priority_);
234 }
235 
236 std::unique_ptr<HttpStream>
RenewStreamForAuth()237 WebSocketHttp2HandshakeStream::RenewStreamForAuth() {
238   // Renewing the stream is not supported.
239   return nullptr;
240 }
241 
GetDnsAliases() const242 const std::set<std::string>& WebSocketHttp2HandshakeStream::GetDnsAliases()
243     const {
244   return dns_aliases_;
245 }
246 
GetAcceptChViaAlps() const247 std::string_view WebSocketHttp2HandshakeStream::GetAcceptChViaAlps() const {
248   return {};
249 }
250 
Upgrade()251 std::unique_ptr<WebSocketStream> WebSocketHttp2HandshakeStream::Upgrade() {
252   DCHECK(extension_params_.get());
253 
254   stream_adapter_->DetachDelegate();
255   std::unique_ptr<WebSocketStream> basic_stream =
256       std::make_unique<WebSocketBasicStream>(std::move(stream_adapter_),
257                                              nullptr, sub_protocol_,
258                                              extensions_, net_log_);
259 
260   if (!extension_params_->deflate_enabled)
261     return basic_stream;
262 
263   return std::make_unique<WebSocketDeflateStream>(
264       std::move(basic_stream), extension_params_->deflate_parameters,
265       std::make_unique<WebSocketDeflatePredictorImpl>());
266 }
267 
CanReadFromStream() const268 bool WebSocketHttp2HandshakeStream::CanReadFromStream() const {
269   return stream_adapter_ && stream_adapter_->is_initialized();
270 }
271 
272 base::WeakPtr<WebSocketHandshakeStreamBase>
GetWeakPtr()273 WebSocketHttp2HandshakeStream::GetWeakPtr() {
274   return weak_ptr_factory_.GetWeakPtr();
275 }
276 
OnHeadersSent()277 void WebSocketHttp2HandshakeStream::OnHeadersSent() {
278   std::move(callback_).Run(OK);
279 }
280 
OnHeadersReceived(const spdy::Http2HeaderBlock & response_headers)281 void WebSocketHttp2HandshakeStream::OnHeadersReceived(
282     const spdy::Http2HeaderBlock& response_headers) {
283   DCHECK(!response_headers_complete_);
284   DCHECK(http_response_info_);
285 
286   response_headers_complete_ = true;
287 
288   const int rv =
289       SpdyHeadersToHttpResponse(response_headers, http_response_info_);
290   DCHECK_NE(rv, ERR_INCOMPLETE_HTTP2_HEADERS);
291 
292   http_response_info_->response_time = stream_->response_time();
293   // Do not store SSLInfo in the response here, HttpNetworkTransaction will take
294   // care of that part.
295   http_response_info_->was_alpn_negotiated = true;
296   http_response_info_->request_time = stream_->GetRequestTime();
297   http_response_info_->connection_info = HttpConnectionInfo::kHTTP2;
298   http_response_info_->alpn_negotiated_protocol =
299       HttpConnectionInfoToString(http_response_info_->connection_info);
300 
301   if (callback_)
302     std::move(callback_).Run(ValidateResponse());
303 }
304 
OnClose(int status)305 void WebSocketHttp2HandshakeStream::OnClose(int status) {
306   DCHECK(stream_adapter_);
307   DCHECK_GT(ERR_IO_PENDING, status);
308 
309   stream_closed_ = true;
310   stream_error_ = status;
311   stream_ = nullptr;
312 
313   stream_adapter_.reset();
314 
315   // If response headers have already been received,
316   // then ValidateResponse() sets |result_|.
317   if (!response_headers_complete_)
318     result_ = HandshakeResult::HTTP2_FAILED;
319 
320   OnFailure(base::StrCat({"Stream closed with error: ", ErrorToString(status)}),
321             status, std::nullopt);
322 
323   if (callback_)
324     std::move(callback_).Run(status);
325 }
326 
StartRequestCallback(int rv)327 void WebSocketHttp2HandshakeStream::StartRequestCallback(int rv) {
328   DCHECK(callback_);
329   if (rv != OK) {
330     spdy_stream_request_.reset();
331     std::move(callback_).Run(rv);
332     return;
333   }
334   stream_ = spdy_stream_request_->ReleaseStream();
335   spdy_stream_request_.reset();
336   stream_adapter_ =
337       std::make_unique<WebSocketSpdyStreamAdapter>(stream_, this, net_log_);
338   rv = stream_->SendRequestHeaders(std::move(http2_request_headers_),
339                                    MORE_DATA_TO_SEND);
340   // SendRequestHeaders() always returns asynchronously,
341   // and instead of taking a callback, it calls OnHeadersSent().
342   DCHECK_EQ(ERR_IO_PENDING, rv);
343 }
344 
ValidateResponse()345 int WebSocketHttp2HandshakeStream::ValidateResponse() {
346   DCHECK(http_response_info_);
347   const HttpResponseHeaders* headers = http_response_info_->headers.get();
348   const int response_code = headers->response_code();
349   switch (response_code) {
350     case HTTP_OK:
351       return ValidateUpgradeResponse(headers);
352 
353     // We need to pass these through for authentication to work.
354     case HTTP_UNAUTHORIZED:
355     case HTTP_PROXY_AUTHENTICATION_REQUIRED:
356       return OK;
357 
358     // Other status codes are potentially risky (see the warnings in the
359     // WHATWG WebSocket API spec) and so are dropped by default.
360     default:
361       OnFailure(
362           base::StringPrintf(
363               "Error during WebSocket handshake: Unexpected response code: %d",
364               headers->response_code()),
365           ERR_FAILED, headers->response_code());
366       result_ = HandshakeResult::HTTP2_INVALID_STATUS;
367       return ERR_INVALID_RESPONSE;
368   }
369 }
370 
ValidateUpgradeResponse(const HttpResponseHeaders * headers)371 int WebSocketHttp2HandshakeStream::ValidateUpgradeResponse(
372     const HttpResponseHeaders* headers) {
373   extension_params_ = std::make_unique<WebSocketExtensionParams>();
374   std::string failure_message;
375   if (!ValidateStatus(headers)) {
376     result_ = HandshakeResult::HTTP2_INVALID_STATUS;
377   } else if (!ValidateSubProtocol(headers, requested_sub_protocols_,
378                                   &sub_protocol_, &failure_message)) {
379     result_ = HandshakeResult::HTTP2_FAILED_SUBPROTO;
380   } else if (!ValidateExtensions(headers, &extensions_, &failure_message,
381                                  extension_params_.get())) {
382     result_ = HandshakeResult::HTTP2_FAILED_EXTENSIONS;
383   } else {
384     result_ = HandshakeResult::HTTP2_CONNECTED;
385     return OK;
386   }
387 
388   const int rv = ERR_INVALID_RESPONSE;
389   OnFailure("Error during WebSocket handshake: " + failure_message, rv,
390             std::nullopt);
391   return rv;
392 }
393 
OnFailure(const std::string & message,int net_error,std::optional<int> response_code)394 void WebSocketHttp2HandshakeStream::OnFailure(
395     const std::string& message,
396     int net_error,
397     std::optional<int> response_code) {
398   stream_request_->OnFailure(message, net_error, response_code);
399 }
400 
401 }  // namespace net
402