xref: /aosp_15_r20/external/cronet/net/websockets/websocket_http3_handshake_stream.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2023 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_http3_handshake_stream.h"
6 
7 #include <string_view>
8 #include <utility>
9 
10 #include "base/check.h"
11 #include "base/check_op.h"
12 #include "base/functional/bind.h"
13 #include "base/functional/callback.h"
14 #include "base/memory/scoped_refptr.h"
15 #include "base/strings/strcat.h"
16 #include "base/strings/stringprintf.h"
17 #include "base/time/time.h"
18 #include "net/base/ip_endpoint.h"
19 #include "net/http/http_request_headers.h"
20 #include "net/http/http_request_info.h"
21 #include "net/http/http_response_headers.h"
22 #include "net/http/http_response_info.h"
23 #include "net/http/http_status_code.h"
24 #include "net/spdy/spdy_http_utils.h"
25 #include "net/third_party/quiche/src/quiche/spdy/core/http2_header_block.h"
26 #include "net/traffic_annotation/network_traffic_annotation.h"
27 #include "net/websockets/websocket_basic_stream.h"
28 #include "net/websockets/websocket_deflate_predictor_impl.h"
29 #include "net/websockets/websocket_deflate_stream.h"
30 #include "net/websockets/websocket_handshake_constants.h"
31 #include "net/websockets/websocket_handshake_request_info.h"
32 
33 namespace net {
34 struct AlternativeService;
35 
36 namespace {
37 
ValidateStatus(const HttpResponseHeaders * headers)38 bool ValidateStatus(const HttpResponseHeaders* headers) {
39   return headers->GetStatusLine() == "HTTP/1.1 200";
40 }
41 
42 }  // namespace
43 
WebSocketHttp3HandshakeStream(std::unique_ptr<QuicChromiumClientSession::Handle> session,WebSocketStream::ConnectDelegate * connect_delegate,std::vector<std::string> requested_sub_protocols,std::vector<std::string> requested_extensions,WebSocketStreamRequestAPI * request,std::set<std::string> dns_aliases)44 WebSocketHttp3HandshakeStream::WebSocketHttp3HandshakeStream(
45     std::unique_ptr<QuicChromiumClientSession::Handle> session,
46     WebSocketStream::ConnectDelegate* connect_delegate,
47     std::vector<std::string> requested_sub_protocols,
48     std::vector<std::string> requested_extensions,
49     WebSocketStreamRequestAPI* request,
50     std::set<std::string> dns_aliases)
51     : session_(std::move(session)),
52       connect_delegate_(connect_delegate),
53       requested_sub_protocols_(std::move(requested_sub_protocols)),
54       requested_extensions_(std::move(requested_extensions)),
55       stream_request_(request),
56       dns_aliases_(std::move(dns_aliases)) {
57   DCHECK(connect_delegate);
58   DCHECK(request);
59 }
60 
~WebSocketHttp3HandshakeStream()61 WebSocketHttp3HandshakeStream::~WebSocketHttp3HandshakeStream() {
62   RecordHandshakeResult(result_);
63 }
64 
RegisterRequest(const HttpRequestInfo * request_info)65 void WebSocketHttp3HandshakeStream::RegisterRequest(
66     const HttpRequestInfo* request_info) {
67   DCHECK(request_info);
68   DCHECK(request_info->traffic_annotation.is_valid());
69   request_info_ = request_info;
70 }
71 
InitializeStream(bool can_send_early,RequestPriority priority,const NetLogWithSource & net_log,CompletionOnceCallback callback)72 int WebSocketHttp3HandshakeStream::InitializeStream(
73     bool can_send_early,
74     RequestPriority priority,
75     const NetLogWithSource& net_log,
76     CompletionOnceCallback callback) {
77   priority_ = priority;
78   net_log_ = net_log;
79   request_time_ = base::Time::Now();
80   return OK;
81 }
82 
SendRequest(const HttpRequestHeaders & request_headers,HttpResponseInfo * response,CompletionOnceCallback callback)83 int WebSocketHttp3HandshakeStream::SendRequest(
84     const HttpRequestHeaders& request_headers,
85     HttpResponseInfo* response,
86     CompletionOnceCallback callback) {
87   DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketKey));
88   DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketProtocol));
89   DCHECK(!request_headers.HasHeader(websockets::kSecWebSocketExtensions));
90   DCHECK(request_headers.HasHeader(HttpRequestHeaders::kOrigin));
91   DCHECK(request_headers.HasHeader(websockets::kUpgrade));
92   DCHECK(request_headers.HasHeader(HttpRequestHeaders::kConnection));
93   DCHECK(request_headers.HasHeader(websockets::kSecWebSocketVersion));
94 
95   if (!session_) {
96     constexpr int rv = ERR_CONNECTION_CLOSED;
97     OnFailure("Connection closed before sending request.", rv, std::nullopt);
98     return rv;
99   }
100 
101   http_response_info_ = response;
102 
103   IPEndPoint address;
104   int result = session_->GetPeerAddress(&address);
105   if (result != OK) {
106     OnFailure("Error getting IP address.", result, std::nullopt);
107     return result;
108   }
109   http_response_info_->remote_endpoint = address;
110 
111   auto request = std::make_unique<WebSocketHandshakeRequestInfo>(
112       request_info_->url, base::Time::Now());
113   request->headers = request_headers;
114 
115   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
116                             requested_extensions_, &request->headers);
117   AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
118                             requested_sub_protocols_, &request->headers);
119 
120   CreateSpdyHeadersFromHttpRequestForWebSocket(
121       request_info_->url, request->headers, &http3_request_headers_);
122 
123   connect_delegate_->OnStartOpeningHandshake(std::move(request));
124 
125   callback_ = std::move(callback);
126 
127   std::unique_ptr<WebSocketQuicStreamAdapter> stream_adapter =
128       session_->CreateWebSocketQuicStreamAdapter(
129           this,
130           base::BindOnce(
131               &WebSocketHttp3HandshakeStream::ReceiveAdapterAndStartRequest,
132               base::Unretained(this)),
133           NetworkTrafficAnnotationTag(request_info_->traffic_annotation));
134   if (!stream_adapter) {
135     return ERR_IO_PENDING;
136   }
137   ReceiveAdapterAndStartRequest(std::move(stream_adapter));
138   return OK;
139 }
140 
ReadResponseHeaders(CompletionOnceCallback callback)141 int WebSocketHttp3HandshakeStream::ReadResponseHeaders(
142     CompletionOnceCallback callback) {
143   if (stream_closed_) {
144     return stream_error_;
145   }
146 
147   if (response_headers_complete_) {
148     return ValidateResponse();
149   }
150 
151   callback_ = std::move(callback);
152   return ERR_IO_PENDING;
153 }
154 
155 // TODO(momoka): Implement this.
ReadResponseBody(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)156 int WebSocketHttp3HandshakeStream::ReadResponseBody(
157     IOBuffer* buf,
158     int buf_len,
159     CompletionOnceCallback callback) {
160   return OK;
161 }
162 
Close(bool not_reusable)163 void WebSocketHttp3HandshakeStream::Close(bool not_reusable) {
164   if (stream_adapter_) {
165     stream_adapter_->Disconnect();
166     stream_closed_ = true;
167     stream_error_ = ERR_CONNECTION_CLOSED;
168   }
169 }
170 
171 // TODO(momoka): Implement this.
IsResponseBodyComplete() const172 bool WebSocketHttp3HandshakeStream::IsResponseBodyComplete() const {
173   return false;
174 }
175 
176 // TODO(momoka): Implement this.
IsConnectionReused() const177 bool WebSocketHttp3HandshakeStream::IsConnectionReused() const {
178   return true;
179 }
180 
181 // TODO(momoka): Implement this.
SetConnectionReused()182 void WebSocketHttp3HandshakeStream::SetConnectionReused() {}
183 
184 // TODO(momoka): Implement this.
CanReuseConnection() const185 bool WebSocketHttp3HandshakeStream::CanReuseConnection() const {
186   return false;
187 }
188 
189 // TODO(momoka): Implement this.
GetTotalReceivedBytes() const190 int64_t WebSocketHttp3HandshakeStream::GetTotalReceivedBytes() const {
191   return 0;
192 }
193 
194 // TODO(momoka): Implement this.
GetTotalSentBytes() const195 int64_t WebSocketHttp3HandshakeStream::GetTotalSentBytes() const {
196   return 0;
197 }
198 
199 // TODO(momoka): Implement this.
GetAlternativeService(AlternativeService * alternative_service) const200 bool WebSocketHttp3HandshakeStream::GetAlternativeService(
201     AlternativeService* alternative_service) const {
202   return false;
203 }
204 
205 // TODO(momoka): Implement this.
GetLoadTimingInfo(LoadTimingInfo * load_timing_info) const206 bool WebSocketHttp3HandshakeStream::GetLoadTimingInfo(
207     LoadTimingInfo* load_timing_info) const {
208   return false;
209 }
210 
211 // TODO(momoka): Implement this.
GetSSLInfo(SSLInfo * ssl_info)212 void WebSocketHttp3HandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {}
213 
214 // TODO(momoka): Implement this.
GetRemoteEndpoint(IPEndPoint * endpoint)215 int WebSocketHttp3HandshakeStream::GetRemoteEndpoint(IPEndPoint* endpoint) {
216   return 0;
217 }
218 
219 // TODO(momoka): Implement this.
Drain(HttpNetworkSession * session)220 void WebSocketHttp3HandshakeStream::Drain(HttpNetworkSession* session) {}
221 
222 // TODO(momoka): Implement this.
SetPriority(RequestPriority priority)223 void WebSocketHttp3HandshakeStream::SetPriority(RequestPriority priority) {}
224 
225 // TODO(momoka): Implement this.
PopulateNetErrorDetails(NetErrorDetails * details)226 void WebSocketHttp3HandshakeStream::PopulateNetErrorDetails(
227     NetErrorDetails* details) {}
228 
229 // TODO(momoka): Implement this.
230 std::unique_ptr<HttpStream>
RenewStreamForAuth()231 WebSocketHttp3HandshakeStream::RenewStreamForAuth() {
232   return nullptr;
233 }
234 
235 // TODO(momoka): Implement this.
GetDnsAliases() const236 const std::set<std::string>& WebSocketHttp3HandshakeStream::GetDnsAliases()
237     const {
238   return dns_aliases_;
239 }
240 
241 // TODO(momoka): Implement this.
GetAcceptChViaAlps() const242 std::string_view WebSocketHttp3HandshakeStream::GetAcceptChViaAlps() const {
243   return {};
244 }
245 
246 // WebSocketHandshakeStreamBase methods.
247 
248 // TODO(momoka): Implement this.
Upgrade()249 std::unique_ptr<WebSocketStream> WebSocketHttp3HandshakeStream::Upgrade() {
250   DCHECK(extension_params_.get());
251 
252   stream_adapter_->clear_delegate();
253   std::unique_ptr<WebSocketStream> basic_stream =
254       std::make_unique<WebSocketBasicStream>(std::move(stream_adapter_),
255                                              nullptr, sub_protocol_,
256                                              extensions_, net_log_);
257 
258   if (!extension_params_->deflate_enabled) {
259     return basic_stream;
260   }
261 
262   return std::make_unique<WebSocketDeflateStream>(
263       std::move(basic_stream), extension_params_->deflate_parameters,
264       std::make_unique<WebSocketDeflatePredictorImpl>());
265 }
266 
CanReadFromStream() const267 bool WebSocketHttp3HandshakeStream::CanReadFromStream() const {
268   return stream_adapter_ && stream_adapter_->is_initialized();
269 }
270 
271 base::WeakPtr<WebSocketHandshakeStreamBase>
GetWeakPtr()272 WebSocketHttp3HandshakeStream::GetWeakPtr() {
273   return weak_ptr_factory_.GetWeakPtr();
274 }
275 
OnHeadersSent()276 void WebSocketHttp3HandshakeStream::OnHeadersSent() {
277   std::move(callback_).Run(OK);
278 }
279 
OnHeadersReceived(const spdy::Http2HeaderBlock & response_headers)280 void WebSocketHttp3HandshakeStream::OnHeadersReceived(
281     const spdy::Http2HeaderBlock& response_headers) {
282   DCHECK(!response_headers_complete_);
283   DCHECK(http_response_info_);
284 
285   response_headers_complete_ = true;
286 
287   const int rv =
288       SpdyHeadersToHttpResponse(response_headers, http_response_info_);
289   DCHECK_NE(rv, ERR_INCOMPLETE_HTTP2_HEADERS);
290 
291   // Do not store SSLInfo in the response here, HttpNetworkTransaction will take
292   // care of that part.
293   http_response_info_->was_alpn_negotiated = true;
294   http_response_info_->response_time = base::Time::Now();
295   http_response_info_->request_time = request_time_;
296   http_response_info_->connection_info = HttpConnectionInfo::kHTTP2;
297   http_response_info_->alpn_negotiated_protocol =
298       HttpConnectionInfoToString(http_response_info_->connection_info);
299 
300   if (callback_) {
301     std::move(callback_).Run(ValidateResponse());
302   }
303 }
304 
OnClose(int status)305 void WebSocketHttp3HandshakeStream::OnClose(int status) {
306   DCHECK(stream_adapter_);
307   DCHECK_GT(ERR_IO_PENDING, status);
308 
309   stream_closed_ = true;
310   stream_error_ = status;
311 
312   stream_adapter_.reset();
313 
314   // If response headers have already been received,
315   // then ValidateResponse() sets `result_`.
316   if (!response_headers_complete_) {
317     result_ = HandshakeResult::HTTP3_FAILED;
318   }
319 
320   OnFailure(base::StrCat({"Stream closed with error: ", ErrorToString(status)}),
321             status, std::nullopt);
322 
323   if (callback_) {
324     std::move(callback_).Run(status);
325   }
326 }
327 
ReceiveAdapterAndStartRequest(std::unique_ptr<WebSocketQuicStreamAdapter> adapter)328 void WebSocketHttp3HandshakeStream::ReceiveAdapterAndStartRequest(
329     std::unique_ptr<WebSocketQuicStreamAdapter> adapter) {
330   stream_adapter_ = std::move(adapter);
331   // WriteHeaders returns synchronously.
332   stream_adapter_->WriteHeaders(std::move(http3_request_headers_), false);
333 }
334 
ValidateResponse()335 int WebSocketHttp3HandshakeStream::ValidateResponse() {
336   DCHECK(http_response_info_);
337   const HttpResponseHeaders* headers = http_response_info_->headers.get();
338   const int response_code = headers->response_code();
339   switch (response_code) {
340     case HTTP_OK:
341       return ValidateUpgradeResponse(headers);
342 
343     // We need to pass these through for authentication to work.
344     case HTTP_UNAUTHORIZED:
345     case HTTP_PROXY_AUTHENTICATION_REQUIRED:
346       return OK;
347 
348     // Other status codes are potentially risky (see the warnings in the
349     // WHATWG WebSocket API spec) and so are dropped by default.
350     default:
351       OnFailure(
352           base::StringPrintf(
353               "Error during WebSocket handshake: Unexpected response code: %d",
354               headers->response_code()),
355           ERR_FAILED, headers->response_code());
356       result_ = HandshakeResult::HTTP3_INVALID_STATUS;
357       return ERR_INVALID_RESPONSE;
358   }
359 }
360 
ValidateUpgradeResponse(const HttpResponseHeaders * headers)361 int WebSocketHttp3HandshakeStream::ValidateUpgradeResponse(
362     const HttpResponseHeaders* headers) {
363   extension_params_ = std::make_unique<WebSocketExtensionParams>();
364   std::string failure_message;
365   if (!ValidateStatus(headers)) {
366     result_ = HandshakeResult::HTTP3_INVALID_STATUS;
367   } else if (!ValidateSubProtocol(headers, requested_sub_protocols_,
368                                   &sub_protocol_, &failure_message)) {
369     result_ = HandshakeResult::HTTP3_FAILED_SUBPROTO;
370   } else if (!ValidateExtensions(headers, &extensions_, &failure_message,
371                                  extension_params_.get())) {
372     result_ = HandshakeResult::HTTP3_FAILED_EXTENSIONS;
373   } else {
374     result_ = HandshakeResult::HTTP3_CONNECTED;
375     return OK;
376   }
377 
378   const int rv = ERR_INVALID_RESPONSE;
379   OnFailure("Error during WebSocket handshake: " + failure_message, rv,
380             std::nullopt);
381   return rv;
382 }
383 
384 // TODO(momoka): Implement this.
OnFailure(const std::string & message,int net_error,std::optional<int> response_code)385 void WebSocketHttp3HandshakeStream::OnFailure(
386     const std::string& message,
387     int net_error,
388     std::optional<int> response_code) {
389   stream_request_->OnFailure(message, net_error, response_code);
390 }
391 
392 }  // namespace net
393