1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_channel.h"
6
7 #include <limits.h> // for INT_MAX
8 #include <stddef.h>
9 #include <string.h>
10
11 #include <algorithm>
12 #include <iterator>
13 #include <ostream>
14 #include <string_view>
15 #include <utility>
16 #include <vector>
17
18 #include "base/big_endian.h"
19 #include "base/check.h"
20 #include "base/check_op.h"
21 #include "base/functional/bind.h"
22 #include "base/location.h"
23 #include "base/logging.h"
24 #include "base/memory/raw_ptr.h"
25 #include "base/numerics/byte_conversions.h"
26 #include "base/numerics/safe_conversions.h"
27 #include "base/ranges/algorithm.h"
28 #include "base/strings/stringprintf.h"
29 #include "base/time/time.h"
30 #include "base/values.h"
31 #include "net/base/io_buffer.h"
32 #include "net/base/net_errors.h"
33 #include "net/http/http_response_headers.h"
34 #include "net/log/net_log_event_type.h"
35 #include "net/log/net_log_with_source.h"
36 #include "net/traffic_annotation/network_traffic_annotation.h"
37 #include "net/websockets/websocket_errors.h"
38 #include "net/websockets/websocket_event_interface.h"
39 #include "net/websockets/websocket_frame.h"
40 #include "net/websockets/websocket_handshake_request_info.h"
41 #include "net/websockets/websocket_handshake_response_info.h"
42 #include "net/websockets/websocket_stream.h"
43
44 namespace net {
45 class AuthChallengeInfo;
46 class AuthCredentials;
47 class SSLInfo;
48
49 namespace {
50
51 using base::StreamingUtf8Validator;
52
53 constexpr size_t kWebSocketCloseCodeLength = 2;
54 // Timeout for waiting for the server to acknowledge a closing handshake.
55 constexpr int kClosingHandshakeTimeoutSeconds = 60;
56 // We wait for the server to close the underlying connection as recommended in
57 // https://tools.ietf.org/html/rfc6455#section-7.1.1
58 // We don't use 2MSL since there're server implementations that don't follow
59 // the recommendation and wait for the client to close the underlying
60 // connection. It leads to unnecessarily long time before CloseEvent
61 // invocation. We want to avoid this rather than strictly following the spec
62 // recommendation.
63 constexpr int kUnderlyingConnectionCloseTimeoutSeconds = 2;
64
65 using ChannelState = WebSocketChannel::ChannelState;
66
67 // Maximum close reason length = max control frame payload -
68 // status code length
69 // = 125 - 2
70 constexpr size_t kMaximumCloseReasonLength = 125 - kWebSocketCloseCodeLength;
71
72 // Check a close status code for strict compliance with RFC6455. This is only
73 // used for close codes received from a renderer that we are intending to send
74 // out over the network. See ParseClose() for the restrictions on incoming close
75 // codes. The |code| parameter is type int for convenience of implementation;
76 // the real type is uint16_t. Code 1005 is treated specially; it cannot be set
77 // explicitly by Javascript but the renderer uses it to indicate we should send
78 // a Close frame with no payload.
IsStrictlyValidCloseStatusCode(int code)79 bool IsStrictlyValidCloseStatusCode(int code) {
80 static constexpr int kInvalidRanges[] = {
81 // [BAD, OK)
82 0, 1000, // 1000 is the first valid code
83 1006, 1007, // 1006 MUST NOT be set.
84 1014, 3000, // 1014 unassigned; 1015 up to 2999 are reserved.
85 5000, 65536, // Codes above 5000 are invalid.
86 };
87 const int* const kInvalidRangesEnd =
88 kInvalidRanges + std::size(kInvalidRanges);
89
90 DCHECK_GE(code, 0);
91 DCHECK_LT(code, 65536);
92 const int* upper = std::upper_bound(kInvalidRanges, kInvalidRangesEnd, code);
93 DCHECK_NE(kInvalidRangesEnd, upper);
94 DCHECK_GT(upper, kInvalidRanges);
95 DCHECK_GT(*upper, code);
96 DCHECK_LE(*(upper - 1), code);
97 return ((upper - kInvalidRanges) % 2) == 0;
98 }
99
100 // Sets |name| to the name of the frame type for the given |opcode|. Note that
101 // for all of Text, Binary and Continuation opcode, this method returns
102 // "Data frame".
GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,std::string * name)103 void GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,
104 std::string* name) {
105 switch (opcode) {
106 case WebSocketFrameHeader::kOpCodeText: // fall-thru
107 case WebSocketFrameHeader::kOpCodeBinary: // fall-thru
108 case WebSocketFrameHeader::kOpCodeContinuation:
109 *name = "Data frame";
110 break;
111
112 case WebSocketFrameHeader::kOpCodePing:
113 *name = "Ping";
114 break;
115
116 case WebSocketFrameHeader::kOpCodePong:
117 *name = "Pong";
118 break;
119
120 case WebSocketFrameHeader::kOpCodeClose:
121 *name = "Close";
122 break;
123
124 default:
125 *name = "Unknown frame type";
126 break;
127 }
128
129 return;
130 }
131
NetLogFailParam(uint16_t code,std::string_view reason,std::string_view message)132 base::Value::Dict NetLogFailParam(uint16_t code,
133 std::string_view reason,
134 std::string_view message) {
135 base::Value::Dict dict;
136 dict.Set("code", code);
137 dict.Set("reason", reason);
138 dict.Set("internal_reason", message);
139 return dict;
140 }
141
142 class DependentIOBuffer : public WrappedIOBuffer {
143 public:
DependentIOBuffer(scoped_refptr<IOBufferWithSize> buffer,size_t offset)144 DependentIOBuffer(scoped_refptr<IOBufferWithSize> buffer, size_t offset)
145 : WrappedIOBuffer(buffer->span().subspan(offset)),
146 buffer_(std::move(buffer)) {}
147
148 private:
~DependentIOBuffer()149 ~DependentIOBuffer() override {
150 // Prevent `data_` from dangling should this destructor remove the
151 // last reference to `buffer_`.
152 data_ = nullptr;
153 }
154
155 scoped_refptr<IOBufferWithSize> buffer_;
156 };
157
158 } // namespace
159
160 // A class to encapsulate a set of frames and information about the size of
161 // those frames.
162 class WebSocketChannel::SendBuffer {
163 public:
164 SendBuffer() = default;
165
166 // Add a WebSocketFrame to the buffer and increase total_bytes_.
167 void AddFrame(std::unique_ptr<WebSocketFrame> chunk,
168 scoped_refptr<IOBuffer> buffer);
169
170 // Return a pointer to the frames_ for write purposes.
frames()171 std::vector<std::unique_ptr<WebSocketFrame>>* frames() { return &frames_; }
172
173 private:
174 // The frames_ that will be sent in the next call to WriteFrames().
175 std::vector<std::unique_ptr<WebSocketFrame>> frames_;
176 // References of each WebSocketFrame.data;
177 std::vector<scoped_refptr<IOBuffer>> buffers_;
178
179 // The total size of the payload data in |frames_|. This will be used to
180 // measure the throughput of the link.
181 // TODO(ricea): Measure the throughput of the link.
182 uint64_t total_bytes_ = 0;
183 };
184
AddFrame(std::unique_ptr<WebSocketFrame> frame,scoped_refptr<IOBuffer> buffer)185 void WebSocketChannel::SendBuffer::AddFrame(
186 std::unique_ptr<WebSocketFrame> frame,
187 scoped_refptr<IOBuffer> buffer) {
188 total_bytes_ += frame->header.payload_length;
189 frames_.push_back(std::move(frame));
190 buffers_.push_back(std::move(buffer));
191 }
192
193 // Implementation of WebSocketStream::ConnectDelegate that simply forwards the
194 // calls on to the WebSocketChannel that created it.
195 class WebSocketChannel::ConnectDelegate
196 : public WebSocketStream::ConnectDelegate {
197 public:
ConnectDelegate(WebSocketChannel * creator)198 explicit ConnectDelegate(WebSocketChannel* creator) : creator_(creator) {}
199
200 ConnectDelegate(const ConnectDelegate&) = delete;
201 ConnectDelegate& operator=(const ConnectDelegate&) = delete;
202
OnCreateRequest(URLRequest * request)203 void OnCreateRequest(URLRequest* request) override {
204 creator_->OnCreateURLRequest(request);
205 }
206
OnURLRequestConnected(URLRequest * request,const TransportInfo & info)207 void OnURLRequestConnected(URLRequest* request,
208 const TransportInfo& info) override {
209 creator_->OnURLRequestConnected(request, info);
210 }
211
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)212 void OnSuccess(
213 std::unique_ptr<WebSocketStream> stream,
214 std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {
215 creator_->OnConnectSuccess(std::move(stream), std::move(response));
216 // |this| may have been deleted.
217 }
218
OnFailure(const std::string & message,int net_error,std::optional<int> response_code)219 void OnFailure(const std::string& message,
220 int net_error,
221 std::optional<int> response_code) override {
222 creator_->OnConnectFailure(message, net_error, response_code);
223 // |this| has been deleted.
224 }
225
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)226 void OnStartOpeningHandshake(
227 std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {
228 creator_->OnStartOpeningHandshake(std::move(request));
229 }
230
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)231 void OnSSLCertificateError(
232 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
233 ssl_error_callbacks,
234 int net_error,
235 const SSLInfo& ssl_info,
236 bool fatal) override {
237 creator_->OnSSLCertificateError(std::move(ssl_error_callbacks), net_error,
238 ssl_info, fatal);
239 }
240
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,std::optional<AuthCredentials> * credentials)241 int OnAuthRequired(const AuthChallengeInfo& auth_info,
242 scoped_refptr<HttpResponseHeaders> headers,
243 const IPEndPoint& remote_endpoint,
244 base::OnceCallback<void(const AuthCredentials*)> callback,
245 std::optional<AuthCredentials>* credentials) override {
246 return creator_->OnAuthRequired(auth_info, std::move(headers),
247 remote_endpoint, std::move(callback),
248 credentials);
249 }
250
251 private:
252 // A pointer to the WebSocketChannel that created this object. There is no
253 // danger of this pointer being stale, because deleting the WebSocketChannel
254 // cancels the connect process, deleting this object and preventing its
255 // callbacks from being called.
256 const raw_ptr<WebSocketChannel, DanglingUntriaged> creator_;
257 };
258
WebSocketChannel(std::unique_ptr<WebSocketEventInterface> event_interface,URLRequestContext * url_request_context)259 WebSocketChannel::WebSocketChannel(
260 std::unique_ptr<WebSocketEventInterface> event_interface,
261 URLRequestContext* url_request_context)
262 : event_interface_(std::move(event_interface)),
263 url_request_context_(url_request_context),
264 closing_handshake_timeout_(
265 base::Seconds(kClosingHandshakeTimeoutSeconds)),
266 underlying_connection_close_timeout_(
267 base::Seconds(kUnderlyingConnectionCloseTimeoutSeconds)) {}
268
~WebSocketChannel()269 WebSocketChannel::~WebSocketChannel() {
270 // The stream may hold a pointer to read_frames_, and so it needs to be
271 // destroyed first.
272 stream_.reset();
273 // The timer may have a callback pointing back to us, so stop it just in case
274 // someone decides to run the event loop from their destructor.
275 close_timer_.Stop();
276 }
277
SendAddChannelRequest(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation)278 void WebSocketChannel::SendAddChannelRequest(
279 const GURL& socket_url,
280 const std::vector<std::string>& requested_subprotocols,
281 const url::Origin& origin,
282 const SiteForCookies& site_for_cookies,
283 bool has_storage_access,
284 const IsolationInfo& isolation_info,
285 const HttpRequestHeaders& additional_headers,
286 NetworkTrafficAnnotationTag traffic_annotation) {
287 SendAddChannelRequestWithSuppliedCallback(
288 socket_url, requested_subprotocols, origin, site_for_cookies,
289 has_storage_access, isolation_info, additional_headers,
290 traffic_annotation,
291 base::BindOnce(&WebSocketStream::CreateAndConnectStream));
292 }
293
SetState(State new_state)294 void WebSocketChannel::SetState(State new_state) {
295 DCHECK_NE(state_, new_state);
296
297 state_ = new_state;
298 }
299
InClosingState() const300 bool WebSocketChannel::InClosingState() const {
301 // The state RECV_CLOSED is not supported here, because it is only used in one
302 // code path and should not leak into the code in general.
303 DCHECK_NE(RECV_CLOSED, state_)
304 << "InClosingState called with state_ == RECV_CLOSED";
305 return state_ == SEND_CLOSED || state_ == CLOSE_WAIT || state_ == CLOSED;
306 }
307
SendFrame(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,size_t buffer_size)308 WebSocketChannel::ChannelState WebSocketChannel::SendFrame(
309 bool fin,
310 WebSocketFrameHeader::OpCode op_code,
311 scoped_refptr<IOBuffer> buffer,
312 size_t buffer_size) {
313 DCHECK_LE(buffer_size, static_cast<size_t>(INT_MAX));
314 DCHECK(stream_) << "Got SendFrame without a connection established; fin="
315 << fin << " op_code=" << op_code
316 << " buffer_size=" << buffer_size;
317
318 if (InClosingState()) {
319 DVLOG(1) << "SendFrame called in state " << state_
320 << ". This may be a bug, or a harmless race.";
321 return CHANNEL_ALIVE;
322 }
323
324 DCHECK_EQ(state_, CONNECTED);
325
326 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(op_code))
327 << "Got SendFrame with bogus op_code " << op_code << " fin=" << fin
328 << " buffer_size=" << buffer_size;
329
330 if (op_code == WebSocketFrameHeader::kOpCodeText ||
331 (op_code == WebSocketFrameHeader::kOpCodeContinuation &&
332 sending_text_message_)) {
333 StreamingUtf8Validator::State state = outgoing_utf8_validator_.AddBytes(
334 base::make_span(buffer->bytes(), buffer_size));
335 if (state == StreamingUtf8Validator::INVALID ||
336 (state == StreamingUtf8Validator::VALID_MIDPOINT && fin)) {
337 // TODO(ricea): Kill renderer.
338 FailChannel("Browser sent a text frame containing invalid UTF-8",
339 kWebSocketErrorGoingAway, "");
340 return CHANNEL_DELETED;
341 // |this| has been deleted.
342 }
343 sending_text_message_ = !fin;
344 DCHECK(!fin || state == StreamingUtf8Validator::VALID_ENDPOINT);
345 }
346
347 return SendFrameInternal(fin, op_code, std::move(buffer), buffer_size);
348 // |this| may have been deleted.
349 }
350
StartClosingHandshake(uint16_t code,const std::string & reason)351 ChannelState WebSocketChannel::StartClosingHandshake(
352 uint16_t code,
353 const std::string& reason) {
354 if (InClosingState()) {
355 // When the associated renderer process is killed while the channel is in
356 // CLOSING state we reach here.
357 DVLOG(1) << "StartClosingHandshake called in state " << state_
358 << ". This may be a bug, or a harmless race.";
359 return CHANNEL_ALIVE;
360 }
361 if (has_received_close_frame_) {
362 // We reach here if the client wants to start a closing handshake while
363 // the browser is waiting for the client to consume incoming data frames
364 // before responding to a closing handshake initiated by the server.
365 // As the client doesn't want the data frames any more, we can respond to
366 // the closing handshake initiated by the server.
367 return RespondToClosingHandshake();
368 }
369 if (state_ == CONNECTING) {
370 // Abort the in-progress handshake and drop the connection immediately.
371 stream_request_.reset();
372 SetState(CLOSED);
373 DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
374 return CHANNEL_DELETED;
375 }
376 DCHECK_EQ(state_, CONNECTED);
377
378 DCHECK(!close_timer_.IsRunning());
379 // This use of base::Unretained() is safe because we stop the timer in the
380 // destructor.
381 close_timer_.Start(
382 FROM_HERE, closing_handshake_timeout_,
383 base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
384
385 // Javascript actually only permits 1000 and 3000-4999, but the implementation
386 // itself may produce different codes. The length of |reason| is also checked
387 // by Javascript.
388 if (!IsStrictlyValidCloseStatusCode(code) ||
389 reason.size() > kMaximumCloseReasonLength) {
390 // "InternalServerError" is actually used for errors from any endpoint, per
391 // errata 3227 to RFC6455. If the renderer is sending us an invalid code or
392 // reason it must be malfunctioning in some way, and based on that we
393 // interpret this as an internal error.
394 if (SendClose(kWebSocketErrorInternalServerError, "") == CHANNEL_DELETED)
395 return CHANNEL_DELETED;
396 DCHECK_EQ(CONNECTED, state_);
397 SetState(SEND_CLOSED);
398 return CHANNEL_ALIVE;
399 }
400 if (SendClose(code, StreamingUtf8Validator::Validate(reason)
401 ? reason
402 : std::string()) == CHANNEL_DELETED)
403 return CHANNEL_DELETED;
404 DCHECK_EQ(CONNECTED, state_);
405 SetState(SEND_CLOSED);
406 return CHANNEL_ALIVE;
407 }
408
SendAddChannelRequestForTesting(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,WebSocketStreamRequestCreationCallback callback)409 void WebSocketChannel::SendAddChannelRequestForTesting(
410 const GURL& socket_url,
411 const std::vector<std::string>& requested_subprotocols,
412 const url::Origin& origin,
413 const SiteForCookies& site_for_cookies,
414 bool has_storage_access,
415 const IsolationInfo& isolation_info,
416 const HttpRequestHeaders& additional_headers,
417 NetworkTrafficAnnotationTag traffic_annotation,
418 WebSocketStreamRequestCreationCallback callback) {
419 SendAddChannelRequestWithSuppliedCallback(
420 socket_url, requested_subprotocols, origin, site_for_cookies,
421 has_storage_access, isolation_info, additional_headers,
422 traffic_annotation, std::move(callback));
423 }
424
SetClosingHandshakeTimeoutForTesting(base::TimeDelta delay)425 void WebSocketChannel::SetClosingHandshakeTimeoutForTesting(
426 base::TimeDelta delay) {
427 closing_handshake_timeout_ = delay;
428 }
429
SetUnderlyingConnectionCloseTimeoutForTesting(base::TimeDelta delay)430 void WebSocketChannel::SetUnderlyingConnectionCloseTimeoutForTesting(
431 base::TimeDelta delay) {
432 underlying_connection_close_timeout_ = delay;
433 }
434
SendAddChannelRequestWithSuppliedCallback(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,WebSocketStreamRequestCreationCallback callback)435 void WebSocketChannel::SendAddChannelRequestWithSuppliedCallback(
436 const GURL& socket_url,
437 const std::vector<std::string>& requested_subprotocols,
438 const url::Origin& origin,
439 const SiteForCookies& site_for_cookies,
440 bool has_storage_access,
441 const IsolationInfo& isolation_info,
442 const HttpRequestHeaders& additional_headers,
443 NetworkTrafficAnnotationTag traffic_annotation,
444 WebSocketStreamRequestCreationCallback callback) {
445 DCHECK_EQ(FRESHLY_CONSTRUCTED, state_);
446 if (!socket_url.SchemeIsWSOrWSS()) {
447 // TODO(ricea): Kill the renderer (this error should have been caught by
448 // Javascript).
449 event_interface_->OnFailChannel("Invalid scheme", ERR_FAILED, std::nullopt);
450 // |this| is deleted here.
451 return;
452 }
453 socket_url_ = socket_url;
454 auto connect_delegate = std::make_unique<ConnectDelegate>(this);
455 stream_request_ = std::move(callback).Run(
456 socket_url_, requested_subprotocols, origin, site_for_cookies,
457 has_storage_access, isolation_info, additional_headers,
458 url_request_context_.get(), NetLogWithSource(), traffic_annotation,
459 std::move(connect_delegate));
460 SetState(CONNECTING);
461 }
462
OnCreateURLRequest(URLRequest * request)463 void WebSocketChannel::OnCreateURLRequest(URLRequest* request) {
464 event_interface_->OnCreateURLRequest(request);
465 }
466
OnURLRequestConnected(URLRequest * request,const TransportInfo & info)467 void WebSocketChannel::OnURLRequestConnected(URLRequest* request,
468 const TransportInfo& info) {
469 event_interface_->OnURLRequestConnected(request, info);
470 }
471
OnConnectSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)472 void WebSocketChannel::OnConnectSuccess(
473 std::unique_ptr<WebSocketStream> stream,
474 std::unique_ptr<WebSocketHandshakeResponseInfo> response) {
475 DCHECK(stream);
476 DCHECK_EQ(CONNECTING, state_);
477
478 stream_ = std::move(stream);
479
480 SetState(CONNECTED);
481
482 // |stream_request_| is not used once the connection has succeeded.
483 stream_request_.reset();
484
485 event_interface_->OnAddChannelResponse(
486 std::move(response), stream_->GetSubProtocol(), stream_->GetExtensions());
487 // |this| may have been deleted after OnAddChannelResponse.
488 }
489
OnConnectFailure(const std::string & message,int net_error,std::optional<int> response_code)490 void WebSocketChannel::OnConnectFailure(const std::string& message,
491 int net_error,
492 std::optional<int> response_code) {
493 DCHECK_EQ(CONNECTING, state_);
494
495 // Copy the message before we delete its owner.
496 std::string message_copy = message;
497
498 SetState(CLOSED);
499 stream_request_.reset();
500
501 event_interface_->OnFailChannel(message_copy, net_error, response_code);
502 // |this| has been deleted.
503 }
504
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)505 void WebSocketChannel::OnSSLCertificateError(
506 std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
507 ssl_error_callbacks,
508 int net_error,
509 const SSLInfo& ssl_info,
510 bool fatal) {
511 event_interface_->OnSSLCertificateError(
512 std::move(ssl_error_callbacks), socket_url_, net_error, ssl_info, fatal);
513 }
514
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,std::optional<AuthCredentials> * credentials)515 int WebSocketChannel::OnAuthRequired(
516 const AuthChallengeInfo& auth_info,
517 scoped_refptr<HttpResponseHeaders> response_headers,
518 const IPEndPoint& remote_endpoint,
519 base::OnceCallback<void(const AuthCredentials*)> callback,
520 std::optional<AuthCredentials>* credentials) {
521 return event_interface_->OnAuthRequired(
522 auth_info, std::move(response_headers), remote_endpoint,
523 std::move(callback), credentials);
524 }
525
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)526 void WebSocketChannel::OnStartOpeningHandshake(
527 std::unique_ptr<WebSocketHandshakeRequestInfo> request) {
528 event_interface_->OnStartOpeningHandshake(std::move(request));
529 }
530
WriteFrames()531 ChannelState WebSocketChannel::WriteFrames() {
532 int result = OK;
533 do {
534 // This use of base::Unretained is safe because this object owns the
535 // WebSocketStream and destroying it cancels all callbacks.
536 result = stream_->WriteFrames(
537 data_being_sent_->frames(),
538 base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnWriteDone),
539 base::Unretained(this), false));
540 if (result != ERR_IO_PENDING) {
541 if (OnWriteDone(true, result) == CHANNEL_DELETED)
542 return CHANNEL_DELETED;
543 // OnWriteDone() returns CHANNEL_DELETED on error. Here |state_| is
544 // guaranteed to be the same as before OnWriteDone() call.
545 }
546 } while (result == OK && data_being_sent_);
547 return CHANNEL_ALIVE;
548 }
549
OnWriteDone(bool synchronous,int result)550 ChannelState WebSocketChannel::OnWriteDone(bool synchronous, int result) {
551 DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
552 DCHECK_NE(CONNECTING, state_);
553 DCHECK_NE(ERR_IO_PENDING, result);
554 DCHECK(data_being_sent_);
555 switch (result) {
556 case OK:
557 if (data_to_send_next_) {
558 data_being_sent_ = std::move(data_to_send_next_);
559 if (!synchronous)
560 return WriteFrames();
561 } else {
562 data_being_sent_.reset();
563 event_interface_->OnSendDataFrameDone();
564 }
565 return CHANNEL_ALIVE;
566
567 // If a recoverable error condition existed, it would go here.
568
569 default:
570 DCHECK_LT(result, 0)
571 << "WriteFrames() should only return OK or ERR_ codes";
572
573 stream_->Close();
574 SetState(CLOSED);
575 DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
576 return CHANNEL_DELETED;
577 }
578 }
579
ReadFrames()580 ChannelState WebSocketChannel::ReadFrames() {
581 DCHECK(stream_);
582 DCHECK(state_ == CONNECTED || state_ == SEND_CLOSED || state_ == CLOSE_WAIT);
583 DCHECK(read_frames_.empty());
584 if (is_reading_) {
585 return CHANNEL_ALIVE;
586 }
587
588 if (!InClosingState() && has_received_close_frame_) {
589 DCHECK(!event_interface_->HasPendingDataFrames());
590 // We've been waiting for the client to consume the frames before
591 // responding to the closing handshake initiated by the server.
592 if (RespondToClosingHandshake() == CHANNEL_DELETED) {
593 return CHANNEL_DELETED;
594 }
595 }
596
597 // TODO(crbug.com/999235): Remove this CHECK.
598 CHECK(event_interface_);
599 while (!event_interface_->HasPendingDataFrames()) {
600 DCHECK(stream_);
601 // This use of base::Unretained is safe because this object owns the
602 // WebSocketStream, and any pending reads will be cancelled when it is
603 // destroyed.
604 const int result = stream_->ReadFrames(
605 &read_frames_,
606 base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnReadDone),
607 base::Unretained(this), false));
608 if (result == ERR_IO_PENDING) {
609 is_reading_ = true;
610 return CHANNEL_ALIVE;
611 }
612 if (OnReadDone(true, result) == CHANNEL_DELETED) {
613 return CHANNEL_DELETED;
614 }
615 DCHECK_NE(CLOSED, state_);
616 // TODO(crbug.com/999235): Remove this CHECK.
617 CHECK(event_interface_);
618 }
619 return CHANNEL_ALIVE;
620 }
621
OnReadDone(bool synchronous,int result)622 ChannelState WebSocketChannel::OnReadDone(bool synchronous, int result) {
623 DVLOG(3) << "WebSocketChannel::OnReadDone synchronous?" << synchronous
624 << ", result=" << result
625 << ", read_frames_.size=" << read_frames_.size();
626 DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
627 DCHECK_NE(CONNECTING, state_);
628 DCHECK_NE(ERR_IO_PENDING, result);
629 switch (result) {
630 case OK:
631 // ReadFrames() must use ERR_CONNECTION_CLOSED for a closed connection
632 // with no data read, not an empty response.
633 DCHECK(!read_frames_.empty())
634 << "ReadFrames() returned OK, but nothing was read.";
635 for (auto& read_frame : read_frames_) {
636 if (HandleFrame(std::move(read_frame)) == CHANNEL_DELETED)
637 return CHANNEL_DELETED;
638 }
639 read_frames_.clear();
640 DCHECK_NE(CLOSED, state_);
641 if (!synchronous) {
642 is_reading_ = false;
643 if (!event_interface_->HasPendingDataFrames()) {
644 return ReadFrames();
645 }
646 }
647 return CHANNEL_ALIVE;
648
649 case ERR_WS_PROTOCOL_ERROR:
650 // This could be kWebSocketErrorProtocolError (specifically, non-minimal
651 // encoding of payload length) or kWebSocketErrorMessageTooBig, or an
652 // extension-specific error.
653 FailChannel("Invalid frame header", kWebSocketErrorProtocolError,
654 "WebSocket Protocol Error");
655 return CHANNEL_DELETED;
656
657 default:
658 DCHECK_LT(result, 0)
659 << "ReadFrames() should only return OK or ERR_ codes";
660
661 stream_->Close();
662 SetState(CLOSED);
663
664 uint16_t code = kWebSocketErrorAbnormalClosure;
665 std::string reason = "";
666 bool was_clean = false;
667 if (has_received_close_frame_) {
668 code = received_close_code_;
669 reason = received_close_reason_;
670 was_clean = (result == ERR_CONNECTION_CLOSED);
671 }
672
673 DoDropChannel(was_clean, code, reason);
674 return CHANNEL_DELETED;
675 }
676 }
677
HandleFrame(std::unique_ptr<WebSocketFrame> frame)678 ChannelState WebSocketChannel::HandleFrame(
679 std::unique_ptr<WebSocketFrame> frame) {
680 if (frame->header.masked) {
681 // RFC6455 Section 5.1 "A client MUST close a connection if it detects a
682 // masked frame."
683 FailChannel(
684 "A server must not mask any frames that it sends to the "
685 "client.",
686 kWebSocketErrorProtocolError, "Masked frame from server");
687 return CHANNEL_DELETED;
688 }
689 const WebSocketFrameHeader::OpCode opcode = frame->header.opcode;
690 DCHECK(!WebSocketFrameHeader::IsKnownControlOpCode(opcode) ||
691 frame->header.final);
692 if (frame->header.reserved1 || frame->header.reserved2 ||
693 frame->header.reserved3) {
694 FailChannel(
695 base::StringPrintf("One or more reserved bits are on: reserved1 = %d, "
696 "reserved2 = %d, reserved3 = %d",
697 static_cast<int>(frame->header.reserved1),
698 static_cast<int>(frame->header.reserved2),
699 static_cast<int>(frame->header.reserved3)),
700 kWebSocketErrorProtocolError, "Invalid reserved bit");
701 return CHANNEL_DELETED;
702 }
703
704 // Respond to the frame appropriately to its type.
705 return HandleFrameByState(
706 opcode, frame->header.final,
707 base::make_span(frame->payload, base::checked_cast<size_t>(
708 frame->header.payload_length)));
709 }
710
HandleFrameByState(const WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)711 ChannelState WebSocketChannel::HandleFrameByState(
712 const WebSocketFrameHeader::OpCode opcode,
713 bool final,
714 base::span<const char> payload) {
715 DCHECK_NE(RECV_CLOSED, state_)
716 << "HandleFrame() does not support being called re-entrantly from within "
717 "SendClose()";
718 DCHECK_NE(CLOSED, state_);
719 if (state_ == CLOSE_WAIT) {
720 std::string frame_name;
721 GetFrameTypeForOpcode(opcode, &frame_name);
722
723 // FailChannel() won't send another Close frame.
724 FailChannel(frame_name + " received after close",
725 kWebSocketErrorProtocolError, "");
726 return CHANNEL_DELETED;
727 }
728 switch (opcode) {
729 case WebSocketFrameHeader::kOpCodeText: // fall-thru
730 case WebSocketFrameHeader::kOpCodeBinary:
731 case WebSocketFrameHeader::kOpCodeContinuation:
732 return HandleDataFrame(opcode, final, std::move(payload));
733
734 case WebSocketFrameHeader::kOpCodePing:
735 DVLOG(1) << "Got Ping of size " << payload.size();
736 if (state_ == CONNECTED) {
737 auto buffer = base::MakeRefCounted<IOBufferWithSize>(payload.size());
738 base::ranges::copy(payload, buffer->data());
739 return SendFrameInternal(true, WebSocketFrameHeader::kOpCodePong,
740 std::move(buffer), payload.size());
741 }
742 DVLOG(3) << "Ignored ping in state " << state_;
743 return CHANNEL_ALIVE;
744
745 case WebSocketFrameHeader::kOpCodePong:
746 DVLOG(1) << "Got Pong of size " << payload.size();
747 // There is no need to do anything with pong messages.
748 return CHANNEL_ALIVE;
749
750 case WebSocketFrameHeader::kOpCodeClose: {
751 uint16_t code = kWebSocketNormalClosure;
752 std::string reason;
753 std::string message;
754 if (!ParseClose(payload, &code, &reason, &message)) {
755 FailChannel(message, code, reason);
756 return CHANNEL_DELETED;
757 }
758 // TODO(ricea): Find a way to safely log the message from the close
759 // message (escape control codes and so on).
760 return HandleCloseFrame(code, reason);
761 }
762
763 default:
764 FailChannel(base::StringPrintf("Unrecognized frame opcode: %d", opcode),
765 kWebSocketErrorProtocolError, "Unknown opcode");
766 return CHANNEL_DELETED;
767 }
768 }
769
HandleDataFrame(WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)770 ChannelState WebSocketChannel::HandleDataFrame(
771 WebSocketFrameHeader::OpCode opcode,
772 bool final,
773 base::span<const char> payload) {
774 DVLOG(3) << "WebSocketChannel::HandleDataFrame opcode=" << opcode
775 << ", final?" << final << ", data=" << (void*)payload.data()
776 << ", size=" << payload.size();
777 if (state_ != CONNECTED) {
778 DVLOG(3) << "Ignored data packet received in state " << state_;
779 return CHANNEL_ALIVE;
780 }
781 if (has_received_close_frame_) {
782 DVLOG(3) << "Ignored data packet as we've received a close frame.";
783 return CHANNEL_ALIVE;
784 }
785 DCHECK(opcode == WebSocketFrameHeader::kOpCodeContinuation ||
786 opcode == WebSocketFrameHeader::kOpCodeText ||
787 opcode == WebSocketFrameHeader::kOpCodeBinary);
788 const bool got_continuation =
789 (opcode == WebSocketFrameHeader::kOpCodeContinuation);
790 if (got_continuation != expecting_to_handle_continuation_) {
791 const std::string console_log = got_continuation
792 ? "Received unexpected continuation frame."
793 : "Received start of new message but previous message is unfinished.";
794 const std::string reason = got_continuation
795 ? "Unexpected continuation"
796 : "Previous data frame unfinished";
797 FailChannel(console_log, kWebSocketErrorProtocolError, reason);
798 return CHANNEL_DELETED;
799 }
800 expecting_to_handle_continuation_ = !final;
801 WebSocketFrameHeader::OpCode opcode_to_send = opcode;
802 if (!initial_frame_forwarded_ &&
803 opcode == WebSocketFrameHeader::kOpCodeContinuation) {
804 opcode_to_send = receiving_text_message_
805 ? WebSocketFrameHeader::kOpCodeText
806 : WebSocketFrameHeader::kOpCodeBinary;
807 }
808 if (opcode == WebSocketFrameHeader::kOpCodeText ||
809 (opcode == WebSocketFrameHeader::kOpCodeContinuation &&
810 receiving_text_message_)) {
811 // This call is not redundant when size == 0 because it tells us what
812 // the current state is.
813 StreamingUtf8Validator::State state =
814 incoming_utf8_validator_.AddBytes(base::as_byte_span(payload));
815 if (state == StreamingUtf8Validator::INVALID ||
816 (state == StreamingUtf8Validator::VALID_MIDPOINT && final)) {
817 FailChannel("Could not decode a text frame as UTF-8.",
818 kWebSocketErrorProtocolError, "Invalid UTF-8 in text frame");
819 return CHANNEL_DELETED;
820 }
821 receiving_text_message_ = !final;
822 DCHECK(!final || state == StreamingUtf8Validator::VALID_ENDPOINT);
823 }
824 if (payload.size() == 0U && !final)
825 return CHANNEL_ALIVE;
826
827 initial_frame_forwarded_ = !final;
828 // Sends the received frame to the renderer process.
829 event_interface_->OnDataFrame(final, opcode_to_send, payload);
830 return CHANNEL_ALIVE;
831 }
832
HandleCloseFrame(uint16_t code,const std::string & reason)833 ChannelState WebSocketChannel::HandleCloseFrame(uint16_t code,
834 const std::string& reason) {
835 DVLOG(1) << "Got Close with code " << code;
836 switch (state_) {
837 case CONNECTED:
838 has_received_close_frame_ = true;
839 received_close_code_ = code;
840 received_close_reason_ = reason;
841 if (event_interface_->HasPendingDataFrames()) {
842 // We have some data to be sent to the renderer before sending this
843 // frame.
844 return CHANNEL_ALIVE;
845 }
846 return RespondToClosingHandshake();
847
848 case SEND_CLOSED:
849 SetState(CLOSE_WAIT);
850 DCHECK(close_timer_.IsRunning());
851 close_timer_.Stop();
852 // This use of base::Unretained() is safe because we stop the timer
853 // in the destructor.
854 close_timer_.Start(FROM_HERE, underlying_connection_close_timeout_,
855 base::BindOnce(&WebSocketChannel::CloseTimeout,
856 base::Unretained(this)));
857
858 // From RFC6455 section 7.1.5: "Each endpoint
859 // will see the status code sent by the other end as _The WebSocket
860 // Connection Close Code_."
861 has_received_close_frame_ = true;
862 received_close_code_ = code;
863 received_close_reason_ = reason;
864 break;
865
866 default:
867 LOG(DFATAL) << "Got Close in unexpected state " << state_;
868 break;
869 }
870 return CHANNEL_ALIVE;
871 }
872
RespondToClosingHandshake()873 ChannelState WebSocketChannel::RespondToClosingHandshake() {
874 DCHECK(has_received_close_frame_);
875 DCHECK_EQ(CONNECTED, state_);
876 SetState(RECV_CLOSED);
877 if (SendClose(received_close_code_, received_close_reason_) ==
878 CHANNEL_DELETED)
879 return CHANNEL_DELETED;
880 DCHECK_EQ(RECV_CLOSED, state_);
881
882 SetState(CLOSE_WAIT);
883 DCHECK(!close_timer_.IsRunning());
884 // This use of base::Unretained() is safe because we stop the timer
885 // in the destructor.
886 close_timer_.Start(
887 FROM_HERE, underlying_connection_close_timeout_,
888 base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
889
890 event_interface_->OnClosingHandshake();
891 return CHANNEL_ALIVE;
892 }
893
SendFrameInternal(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,uint64_t buffer_size)894 ChannelState WebSocketChannel::SendFrameInternal(
895 bool fin,
896 WebSocketFrameHeader::OpCode op_code,
897 scoped_refptr<IOBuffer> buffer,
898 uint64_t buffer_size) {
899 DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
900 DCHECK(stream_);
901
902 auto frame = std::make_unique<WebSocketFrame>(op_code);
903 WebSocketFrameHeader& header = frame->header;
904 header.final = fin;
905 header.masked = true;
906 header.payload_length = buffer_size;
907 frame->payload = buffer->data();
908
909 if (data_being_sent_) {
910 // Either the link to the WebSocket server is saturated, or several messages
911 // are being sent in a batch.
912 if (!data_to_send_next_)
913 data_to_send_next_ = std::make_unique<SendBuffer>();
914 data_to_send_next_->AddFrame(std::move(frame), std::move(buffer));
915 return CHANNEL_ALIVE;
916 }
917
918 data_being_sent_ = std::make_unique<SendBuffer>();
919 data_being_sent_->AddFrame(std::move(frame), std::move(buffer));
920 return WriteFrames();
921 }
922
FailChannel(const std::string & message,uint16_t code,const std::string & reason)923 void WebSocketChannel::FailChannel(const std::string& message,
924 uint16_t code,
925 const std::string& reason) {
926 DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
927 DCHECK_NE(CONNECTING, state_);
928 DCHECK_NE(CLOSED, state_);
929
930 stream_->GetNetLogWithSource().AddEvent(
931 net::NetLogEventType::WEBSOCKET_INVALID_FRAME,
932 [&] { return NetLogFailParam(code, reason, message); });
933
934 if (state_ == CONNECTED) {
935 if (SendClose(code, reason) == CHANNEL_DELETED)
936 return;
937 }
938
939 // Careful study of RFC6455 section 7.1.7 and 7.1.1 indicates the browser
940 // should close the connection itself without waiting for the closing
941 // handshake.
942 stream_->Close();
943 SetState(CLOSED);
944 event_interface_->OnFailChannel(message, ERR_FAILED, std::nullopt);
945 }
946
SendClose(uint16_t code,const std::string & reason)947 ChannelState WebSocketChannel::SendClose(uint16_t code,
948 const std::string& reason) {
949 DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
950 DCHECK_LE(reason.size(), kMaximumCloseReasonLength);
951 scoped_refptr<IOBuffer> body;
952 uint64_t size = 0;
953 if (code == kWebSocketErrorNoStatusReceived) {
954 // Special case: translate kWebSocketErrorNoStatusReceived into a Close
955 // frame with no payload.
956 DCHECK(reason.empty());
957 body = base::MakeRefCounted<IOBufferWithSize>();
958 } else {
959 const size_t payload_length = kWebSocketCloseCodeLength + reason.length();
960 body = base::MakeRefCounted<IOBufferWithSize>(payload_length);
961 size = payload_length;
962 auto [code_span, body_span] =
963 body->span().split_at<kWebSocketCloseCodeLength>();
964 base::as_writable_bytes(code_span).copy_from(base::U16ToBigEndian(code));
965 static_assert(sizeof(code) == kWebSocketCloseCodeLength,
966 "they should both be two");
967 body_span.copy_from(reason);
968 }
969
970 return SendFrameInternal(true, WebSocketFrameHeader::kOpCodeClose,
971 std::move(body), size);
972 }
973
ParseClose(base::span<const char> payload,uint16_t * code,std::string * reason,std::string * message)974 bool WebSocketChannel::ParseClose(base::span<const char> payload,
975 uint16_t* code,
976 std::string* reason,
977 std::string* message) {
978 const uint64_t size = static_cast<uint64_t>(payload.size());
979 reason->clear();
980 if (size < kWebSocketCloseCodeLength) {
981 if (size == 0U) {
982 *code = kWebSocketErrorNoStatusReceived;
983 return true;
984 }
985
986 DVLOG(1) << "Close frame with payload size " << size << " received "
987 << "(the first byte is " << std::hex
988 << static_cast<int>(payload[0]) << ")";
989 *code = kWebSocketErrorProtocolError;
990 *message =
991 "Received a broken close frame containing an invalid size body.";
992 return false;
993 }
994
995 const char* data = payload.data();
996 uint16_t unchecked_code =
997 base::U16FromBigEndian(base::as_byte_span(payload).first<2>());
998 static_assert(sizeof(unchecked_code) == kWebSocketCloseCodeLength,
999 "they should both be two bytes");
1000
1001 switch (unchecked_code) {
1002 case kWebSocketErrorNoStatusReceived:
1003 case kWebSocketErrorAbnormalClosure:
1004 case kWebSocketErrorTlsHandshake:
1005 *code = kWebSocketErrorProtocolError;
1006 *message =
1007 "Received a broken close frame containing a reserved status code.";
1008 return false;
1009
1010 default:
1011 *code = unchecked_code;
1012 break;
1013 }
1014
1015 std::string text(data + kWebSocketCloseCodeLength, data + size);
1016 if (StreamingUtf8Validator::Validate(text)) {
1017 reason->swap(text);
1018 return true;
1019 }
1020
1021 *code = kWebSocketErrorProtocolError;
1022 *reason = "Invalid UTF-8 in Close frame";
1023 *message = "Received a broken close frame containing invalid UTF-8.";
1024 return false;
1025 }
1026
DoDropChannel(bool was_clean,uint16_t code,const std::string & reason)1027 void WebSocketChannel::DoDropChannel(bool was_clean,
1028 uint16_t code,
1029 const std::string& reason) {
1030 event_interface_->OnDropChannel(was_clean, code, reason);
1031 }
1032
CloseTimeout()1033 void WebSocketChannel::CloseTimeout() {
1034 stream_->GetNetLogWithSource().AddEvent(
1035 net::NetLogEventType::WEBSOCKET_CLOSE_TIMEOUT);
1036 stream_->Close();
1037 SetState(CLOSED);
1038 if (has_received_close_frame_) {
1039 DoDropChannel(true, received_close_code_, received_close_reason_);
1040 } else {
1041 DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
1042 }
1043 // |this| has been deleted.
1044 }
1045
1046 } // namespace net
1047