xref: /aosp_15_r20/external/cronet/net/websockets/websocket_channel.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2013 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_channel.h"
6 
7 #include <limits.h>  // for INT_MAX
8 #include <stddef.h>
9 #include <string.h>
10 
11 #include <algorithm>
12 #include <iterator>
13 #include <ostream>
14 #include <string_view>
15 #include <utility>
16 #include <vector>
17 
18 #include "base/big_endian.h"
19 #include "base/check.h"
20 #include "base/check_op.h"
21 #include "base/functional/bind.h"
22 #include "base/location.h"
23 #include "base/logging.h"
24 #include "base/memory/raw_ptr.h"
25 #include "base/numerics/byte_conversions.h"
26 #include "base/numerics/safe_conversions.h"
27 #include "base/ranges/algorithm.h"
28 #include "base/strings/stringprintf.h"
29 #include "base/time/time.h"
30 #include "base/values.h"
31 #include "net/base/io_buffer.h"
32 #include "net/base/net_errors.h"
33 #include "net/http/http_response_headers.h"
34 #include "net/log/net_log_event_type.h"
35 #include "net/log/net_log_with_source.h"
36 #include "net/traffic_annotation/network_traffic_annotation.h"
37 #include "net/websockets/websocket_errors.h"
38 #include "net/websockets/websocket_event_interface.h"
39 #include "net/websockets/websocket_frame.h"
40 #include "net/websockets/websocket_handshake_request_info.h"
41 #include "net/websockets/websocket_handshake_response_info.h"
42 #include "net/websockets/websocket_stream.h"
43 
44 namespace net {
45 class AuthChallengeInfo;
46 class AuthCredentials;
47 class SSLInfo;
48 
49 namespace {
50 
51 using base::StreamingUtf8Validator;
52 
53 constexpr size_t kWebSocketCloseCodeLength = 2;
54 // Timeout for waiting for the server to acknowledge a closing handshake.
55 constexpr int kClosingHandshakeTimeoutSeconds = 60;
56 // We wait for the server to close the underlying connection as recommended in
57 // https://tools.ietf.org/html/rfc6455#section-7.1.1
58 // We don't use 2MSL since there're server implementations that don't follow
59 // the recommendation and wait for the client to close the underlying
60 // connection. It leads to unnecessarily long time before CloseEvent
61 // invocation. We want to avoid this rather than strictly following the spec
62 // recommendation.
63 constexpr int kUnderlyingConnectionCloseTimeoutSeconds = 2;
64 
65 using ChannelState = WebSocketChannel::ChannelState;
66 
67 // Maximum close reason length = max control frame payload -
68 //                               status code length
69 //                             = 125 - 2
70 constexpr size_t kMaximumCloseReasonLength = 125 - kWebSocketCloseCodeLength;
71 
72 // Check a close status code for strict compliance with RFC6455. This is only
73 // used for close codes received from a renderer that we are intending to send
74 // out over the network. See ParseClose() for the restrictions on incoming close
75 // codes. The |code| parameter is type int for convenience of implementation;
76 // the real type is uint16_t. Code 1005 is treated specially; it cannot be set
77 // explicitly by Javascript but the renderer uses it to indicate we should send
78 // a Close frame with no payload.
IsStrictlyValidCloseStatusCode(int code)79 bool IsStrictlyValidCloseStatusCode(int code) {
80   static constexpr int kInvalidRanges[] = {
81       // [BAD, OK)
82       0,    1000,   // 1000 is the first valid code
83       1006, 1007,   // 1006 MUST NOT be set.
84       1014, 3000,   // 1014 unassigned; 1015 up to 2999 are reserved.
85       5000, 65536,  // Codes above 5000 are invalid.
86   };
87   const int* const kInvalidRangesEnd =
88       kInvalidRanges + std::size(kInvalidRanges);
89 
90   DCHECK_GE(code, 0);
91   DCHECK_LT(code, 65536);
92   const int* upper = std::upper_bound(kInvalidRanges, kInvalidRangesEnd, code);
93   DCHECK_NE(kInvalidRangesEnd, upper);
94   DCHECK_GT(upper, kInvalidRanges);
95   DCHECK_GT(*upper, code);
96   DCHECK_LE(*(upper - 1), code);
97   return ((upper - kInvalidRanges) % 2) == 0;
98 }
99 
100 // Sets |name| to the name of the frame type for the given |opcode|. Note that
101 // for all of Text, Binary and Continuation opcode, this method returns
102 // "Data frame".
GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,std::string * name)103 void GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode,
104                            std::string* name) {
105   switch (opcode) {
106     case WebSocketFrameHeader::kOpCodeText:    // fall-thru
107     case WebSocketFrameHeader::kOpCodeBinary:  // fall-thru
108     case WebSocketFrameHeader::kOpCodeContinuation:
109       *name = "Data frame";
110       break;
111 
112     case WebSocketFrameHeader::kOpCodePing:
113       *name = "Ping";
114       break;
115 
116     case WebSocketFrameHeader::kOpCodePong:
117       *name = "Pong";
118       break;
119 
120     case WebSocketFrameHeader::kOpCodeClose:
121       *name = "Close";
122       break;
123 
124     default:
125       *name = "Unknown frame type";
126       break;
127   }
128 
129   return;
130 }
131 
NetLogFailParam(uint16_t code,std::string_view reason,std::string_view message)132 base::Value::Dict NetLogFailParam(uint16_t code,
133                                   std::string_view reason,
134                                   std::string_view message) {
135   base::Value::Dict dict;
136   dict.Set("code", code);
137   dict.Set("reason", reason);
138   dict.Set("internal_reason", message);
139   return dict;
140 }
141 
142 class DependentIOBuffer : public WrappedIOBuffer {
143  public:
DependentIOBuffer(scoped_refptr<IOBufferWithSize> buffer,size_t offset)144   DependentIOBuffer(scoped_refptr<IOBufferWithSize> buffer, size_t offset)
145       : WrappedIOBuffer(buffer->span().subspan(offset)),
146         buffer_(std::move(buffer)) {}
147 
148  private:
~DependentIOBuffer()149   ~DependentIOBuffer() override {
150     // Prevent `data_` from dangling should this destructor remove the
151     // last reference to `buffer_`.
152     data_ = nullptr;
153   }
154 
155   scoped_refptr<IOBufferWithSize> buffer_;
156 };
157 
158 }  // namespace
159 
160 // A class to encapsulate a set of frames and information about the size of
161 // those frames.
162 class WebSocketChannel::SendBuffer {
163  public:
164   SendBuffer() = default;
165 
166   // Add a WebSocketFrame to the buffer and increase total_bytes_.
167   void AddFrame(std::unique_ptr<WebSocketFrame> chunk,
168                 scoped_refptr<IOBuffer> buffer);
169 
170   // Return a pointer to the frames_ for write purposes.
frames()171   std::vector<std::unique_ptr<WebSocketFrame>>* frames() { return &frames_; }
172 
173  private:
174   // The frames_ that will be sent in the next call to WriteFrames().
175   std::vector<std::unique_ptr<WebSocketFrame>> frames_;
176   // References of each WebSocketFrame.data;
177   std::vector<scoped_refptr<IOBuffer>> buffers_;
178 
179   // The total size of the payload data in |frames_|. This will be used to
180   // measure the throughput of the link.
181   // TODO(ricea): Measure the throughput of the link.
182   uint64_t total_bytes_ = 0;
183 };
184 
AddFrame(std::unique_ptr<WebSocketFrame> frame,scoped_refptr<IOBuffer> buffer)185 void WebSocketChannel::SendBuffer::AddFrame(
186     std::unique_ptr<WebSocketFrame> frame,
187     scoped_refptr<IOBuffer> buffer) {
188   total_bytes_ += frame->header.payload_length;
189   frames_.push_back(std::move(frame));
190   buffers_.push_back(std::move(buffer));
191 }
192 
193 // Implementation of WebSocketStream::ConnectDelegate that simply forwards the
194 // calls on to the WebSocketChannel that created it.
195 class WebSocketChannel::ConnectDelegate
196     : public WebSocketStream::ConnectDelegate {
197  public:
ConnectDelegate(WebSocketChannel * creator)198   explicit ConnectDelegate(WebSocketChannel* creator) : creator_(creator) {}
199 
200   ConnectDelegate(const ConnectDelegate&) = delete;
201   ConnectDelegate& operator=(const ConnectDelegate&) = delete;
202 
OnCreateRequest(URLRequest * request)203   void OnCreateRequest(URLRequest* request) override {
204     creator_->OnCreateURLRequest(request);
205   }
206 
OnURLRequestConnected(URLRequest * request,const TransportInfo & info)207   void OnURLRequestConnected(URLRequest* request,
208                              const TransportInfo& info) override {
209     creator_->OnURLRequestConnected(request, info);
210   }
211 
OnSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)212   void OnSuccess(
213       std::unique_ptr<WebSocketStream> stream,
214       std::unique_ptr<WebSocketHandshakeResponseInfo> response) override {
215     creator_->OnConnectSuccess(std::move(stream), std::move(response));
216     // |this| may have been deleted.
217   }
218 
OnFailure(const std::string & message,int net_error,std::optional<int> response_code)219   void OnFailure(const std::string& message,
220                  int net_error,
221                  std::optional<int> response_code) override {
222     creator_->OnConnectFailure(message, net_error, response_code);
223     // |this| has been deleted.
224   }
225 
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)226   void OnStartOpeningHandshake(
227       std::unique_ptr<WebSocketHandshakeRequestInfo> request) override {
228     creator_->OnStartOpeningHandshake(std::move(request));
229   }
230 
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)231   void OnSSLCertificateError(
232       std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
233           ssl_error_callbacks,
234       int net_error,
235       const SSLInfo& ssl_info,
236       bool fatal) override {
237     creator_->OnSSLCertificateError(std::move(ssl_error_callbacks), net_error,
238                                     ssl_info, fatal);
239   }
240 
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,std::optional<AuthCredentials> * credentials)241   int OnAuthRequired(const AuthChallengeInfo& auth_info,
242                      scoped_refptr<HttpResponseHeaders> headers,
243                      const IPEndPoint& remote_endpoint,
244                      base::OnceCallback<void(const AuthCredentials*)> callback,
245                      std::optional<AuthCredentials>* credentials) override {
246     return creator_->OnAuthRequired(auth_info, std::move(headers),
247                                     remote_endpoint, std::move(callback),
248                                     credentials);
249   }
250 
251  private:
252   // A pointer to the WebSocketChannel that created this object. There is no
253   // danger of this pointer being stale, because deleting the WebSocketChannel
254   // cancels the connect process, deleting this object and preventing its
255   // callbacks from being called.
256   const raw_ptr<WebSocketChannel, DanglingUntriaged> creator_;
257 };
258 
WebSocketChannel(std::unique_ptr<WebSocketEventInterface> event_interface,URLRequestContext * url_request_context)259 WebSocketChannel::WebSocketChannel(
260     std::unique_ptr<WebSocketEventInterface> event_interface,
261     URLRequestContext* url_request_context)
262     : event_interface_(std::move(event_interface)),
263       url_request_context_(url_request_context),
264       closing_handshake_timeout_(
265           base::Seconds(kClosingHandshakeTimeoutSeconds)),
266       underlying_connection_close_timeout_(
267           base::Seconds(kUnderlyingConnectionCloseTimeoutSeconds)) {}
268 
~WebSocketChannel()269 WebSocketChannel::~WebSocketChannel() {
270   // The stream may hold a pointer to read_frames_, and so it needs to be
271   // destroyed first.
272   stream_.reset();
273   // The timer may have a callback pointing back to us, so stop it just in case
274   // someone decides to run the event loop from their destructor.
275   close_timer_.Stop();
276 }
277 
SendAddChannelRequest(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation)278 void WebSocketChannel::SendAddChannelRequest(
279     const GURL& socket_url,
280     const std::vector<std::string>& requested_subprotocols,
281     const url::Origin& origin,
282     const SiteForCookies& site_for_cookies,
283     bool has_storage_access,
284     const IsolationInfo& isolation_info,
285     const HttpRequestHeaders& additional_headers,
286     NetworkTrafficAnnotationTag traffic_annotation) {
287   SendAddChannelRequestWithSuppliedCallback(
288       socket_url, requested_subprotocols, origin, site_for_cookies,
289       has_storage_access, isolation_info, additional_headers,
290       traffic_annotation,
291       base::BindOnce(&WebSocketStream::CreateAndConnectStream));
292 }
293 
SetState(State new_state)294 void WebSocketChannel::SetState(State new_state) {
295   DCHECK_NE(state_, new_state);
296 
297   state_ = new_state;
298 }
299 
InClosingState() const300 bool WebSocketChannel::InClosingState() const {
301   // The state RECV_CLOSED is not supported here, because it is only used in one
302   // code path and should not leak into the code in general.
303   DCHECK_NE(RECV_CLOSED, state_)
304       << "InClosingState called with state_ == RECV_CLOSED";
305   return state_ == SEND_CLOSED || state_ == CLOSE_WAIT || state_ == CLOSED;
306 }
307 
SendFrame(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,size_t buffer_size)308 WebSocketChannel::ChannelState WebSocketChannel::SendFrame(
309     bool fin,
310     WebSocketFrameHeader::OpCode op_code,
311     scoped_refptr<IOBuffer> buffer,
312     size_t buffer_size) {
313   DCHECK_LE(buffer_size, static_cast<size_t>(INT_MAX));
314   DCHECK(stream_) << "Got SendFrame without a connection established; fin="
315                   << fin << " op_code=" << op_code
316                   << " buffer_size=" << buffer_size;
317 
318   if (InClosingState()) {
319     DVLOG(1) << "SendFrame called in state " << state_
320              << ". This may be a bug, or a harmless race.";
321     return CHANNEL_ALIVE;
322   }
323 
324   DCHECK_EQ(state_, CONNECTED);
325 
326   DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(op_code))
327       << "Got SendFrame with bogus op_code " << op_code << " fin=" << fin
328       << " buffer_size=" << buffer_size;
329 
330   if (op_code == WebSocketFrameHeader::kOpCodeText ||
331       (op_code == WebSocketFrameHeader::kOpCodeContinuation &&
332        sending_text_message_)) {
333     StreamingUtf8Validator::State state = outgoing_utf8_validator_.AddBytes(
334         base::make_span(buffer->bytes(), buffer_size));
335     if (state == StreamingUtf8Validator::INVALID ||
336         (state == StreamingUtf8Validator::VALID_MIDPOINT && fin)) {
337       // TODO(ricea): Kill renderer.
338       FailChannel("Browser sent a text frame containing invalid UTF-8",
339                   kWebSocketErrorGoingAway, "");
340       return CHANNEL_DELETED;
341       // |this| has been deleted.
342     }
343     sending_text_message_ = !fin;
344     DCHECK(!fin || state == StreamingUtf8Validator::VALID_ENDPOINT);
345   }
346 
347   return SendFrameInternal(fin, op_code, std::move(buffer), buffer_size);
348   // |this| may have been deleted.
349 }
350 
StartClosingHandshake(uint16_t code,const std::string & reason)351 ChannelState WebSocketChannel::StartClosingHandshake(
352     uint16_t code,
353     const std::string& reason) {
354   if (InClosingState()) {
355     // When the associated renderer process is killed while the channel is in
356     // CLOSING state we reach here.
357     DVLOG(1) << "StartClosingHandshake called in state " << state_
358              << ". This may be a bug, or a harmless race.";
359     return CHANNEL_ALIVE;
360   }
361   if (has_received_close_frame_) {
362     // We reach here if the client wants to start a closing handshake while
363     // the browser is waiting for the client to consume incoming data frames
364     // before responding to a closing handshake initiated by the server.
365     // As the client doesn't want the data frames any more, we can respond to
366     // the closing handshake initiated by the server.
367     return RespondToClosingHandshake();
368   }
369   if (state_ == CONNECTING) {
370     // Abort the in-progress handshake and drop the connection immediately.
371     stream_request_.reset();
372     SetState(CLOSED);
373     DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
374     return CHANNEL_DELETED;
375   }
376   DCHECK_EQ(state_, CONNECTED);
377 
378   DCHECK(!close_timer_.IsRunning());
379   // This use of base::Unretained() is safe because we stop the timer in the
380   // destructor.
381   close_timer_.Start(
382       FROM_HERE, closing_handshake_timeout_,
383       base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
384 
385   // Javascript actually only permits 1000 and 3000-4999, but the implementation
386   // itself may produce different codes. The length of |reason| is also checked
387   // by Javascript.
388   if (!IsStrictlyValidCloseStatusCode(code) ||
389       reason.size() > kMaximumCloseReasonLength) {
390     // "InternalServerError" is actually used for errors from any endpoint, per
391     // errata 3227 to RFC6455. If the renderer is sending us an invalid code or
392     // reason it must be malfunctioning in some way, and based on that we
393     // interpret this as an internal error.
394     if (SendClose(kWebSocketErrorInternalServerError, "") == CHANNEL_DELETED)
395       return CHANNEL_DELETED;
396     DCHECK_EQ(CONNECTED, state_);
397     SetState(SEND_CLOSED);
398     return CHANNEL_ALIVE;
399   }
400   if (SendClose(code, StreamingUtf8Validator::Validate(reason)
401                           ? reason
402                           : std::string()) == CHANNEL_DELETED)
403     return CHANNEL_DELETED;
404   DCHECK_EQ(CONNECTED, state_);
405   SetState(SEND_CLOSED);
406   return CHANNEL_ALIVE;
407 }
408 
SendAddChannelRequestForTesting(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,WebSocketStreamRequestCreationCallback callback)409 void WebSocketChannel::SendAddChannelRequestForTesting(
410     const GURL& socket_url,
411     const std::vector<std::string>& requested_subprotocols,
412     const url::Origin& origin,
413     const SiteForCookies& site_for_cookies,
414     bool has_storage_access,
415     const IsolationInfo& isolation_info,
416     const HttpRequestHeaders& additional_headers,
417     NetworkTrafficAnnotationTag traffic_annotation,
418     WebSocketStreamRequestCreationCallback callback) {
419   SendAddChannelRequestWithSuppliedCallback(
420       socket_url, requested_subprotocols, origin, site_for_cookies,
421       has_storage_access, isolation_info, additional_headers,
422       traffic_annotation, std::move(callback));
423 }
424 
SetClosingHandshakeTimeoutForTesting(base::TimeDelta delay)425 void WebSocketChannel::SetClosingHandshakeTimeoutForTesting(
426     base::TimeDelta delay) {
427   closing_handshake_timeout_ = delay;
428 }
429 
SetUnderlyingConnectionCloseTimeoutForTesting(base::TimeDelta delay)430 void WebSocketChannel::SetUnderlyingConnectionCloseTimeoutForTesting(
431     base::TimeDelta delay) {
432   underlying_connection_close_timeout_ = delay;
433 }
434 
SendAddChannelRequestWithSuppliedCallback(const GURL & socket_url,const std::vector<std::string> & requested_subprotocols,const url::Origin & origin,const SiteForCookies & site_for_cookies,bool has_storage_access,const IsolationInfo & isolation_info,const HttpRequestHeaders & additional_headers,NetworkTrafficAnnotationTag traffic_annotation,WebSocketStreamRequestCreationCallback callback)435 void WebSocketChannel::SendAddChannelRequestWithSuppliedCallback(
436     const GURL& socket_url,
437     const std::vector<std::string>& requested_subprotocols,
438     const url::Origin& origin,
439     const SiteForCookies& site_for_cookies,
440     bool has_storage_access,
441     const IsolationInfo& isolation_info,
442     const HttpRequestHeaders& additional_headers,
443     NetworkTrafficAnnotationTag traffic_annotation,
444     WebSocketStreamRequestCreationCallback callback) {
445   DCHECK_EQ(FRESHLY_CONSTRUCTED, state_);
446   if (!socket_url.SchemeIsWSOrWSS()) {
447     // TODO(ricea): Kill the renderer (this error should have been caught by
448     // Javascript).
449     event_interface_->OnFailChannel("Invalid scheme", ERR_FAILED, std::nullopt);
450     // |this| is deleted here.
451     return;
452   }
453   socket_url_ = socket_url;
454   auto connect_delegate = std::make_unique<ConnectDelegate>(this);
455   stream_request_ = std::move(callback).Run(
456       socket_url_, requested_subprotocols, origin, site_for_cookies,
457       has_storage_access, isolation_info, additional_headers,
458       url_request_context_.get(), NetLogWithSource(), traffic_annotation,
459       std::move(connect_delegate));
460   SetState(CONNECTING);
461 }
462 
OnCreateURLRequest(URLRequest * request)463 void WebSocketChannel::OnCreateURLRequest(URLRequest* request) {
464   event_interface_->OnCreateURLRequest(request);
465 }
466 
OnURLRequestConnected(URLRequest * request,const TransportInfo & info)467 void WebSocketChannel::OnURLRequestConnected(URLRequest* request,
468                                              const TransportInfo& info) {
469   event_interface_->OnURLRequestConnected(request, info);
470 }
471 
OnConnectSuccess(std::unique_ptr<WebSocketStream> stream,std::unique_ptr<WebSocketHandshakeResponseInfo> response)472 void WebSocketChannel::OnConnectSuccess(
473     std::unique_ptr<WebSocketStream> stream,
474     std::unique_ptr<WebSocketHandshakeResponseInfo> response) {
475   DCHECK(stream);
476   DCHECK_EQ(CONNECTING, state_);
477 
478   stream_ = std::move(stream);
479 
480   SetState(CONNECTED);
481 
482   // |stream_request_| is not used once the connection has succeeded.
483   stream_request_.reset();
484 
485   event_interface_->OnAddChannelResponse(
486       std::move(response), stream_->GetSubProtocol(), stream_->GetExtensions());
487   // |this| may have been deleted after OnAddChannelResponse.
488 }
489 
OnConnectFailure(const std::string & message,int net_error,std::optional<int> response_code)490 void WebSocketChannel::OnConnectFailure(const std::string& message,
491                                         int net_error,
492                                         std::optional<int> response_code) {
493   DCHECK_EQ(CONNECTING, state_);
494 
495   // Copy the message before we delete its owner.
496   std::string message_copy = message;
497 
498   SetState(CLOSED);
499   stream_request_.reset();
500 
501   event_interface_->OnFailChannel(message_copy, net_error, response_code);
502   // |this| has been deleted.
503 }
504 
OnSSLCertificateError(std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks,int net_error,const SSLInfo & ssl_info,bool fatal)505 void WebSocketChannel::OnSSLCertificateError(
506     std::unique_ptr<WebSocketEventInterface::SSLErrorCallbacks>
507         ssl_error_callbacks,
508     int net_error,
509     const SSLInfo& ssl_info,
510     bool fatal) {
511   event_interface_->OnSSLCertificateError(
512       std::move(ssl_error_callbacks), socket_url_, net_error, ssl_info, fatal);
513 }
514 
OnAuthRequired(const AuthChallengeInfo & auth_info,scoped_refptr<HttpResponseHeaders> response_headers,const IPEndPoint & remote_endpoint,base::OnceCallback<void (const AuthCredentials *)> callback,std::optional<AuthCredentials> * credentials)515 int WebSocketChannel::OnAuthRequired(
516     const AuthChallengeInfo& auth_info,
517     scoped_refptr<HttpResponseHeaders> response_headers,
518     const IPEndPoint& remote_endpoint,
519     base::OnceCallback<void(const AuthCredentials*)> callback,
520     std::optional<AuthCredentials>* credentials) {
521   return event_interface_->OnAuthRequired(
522       auth_info, std::move(response_headers), remote_endpoint,
523       std::move(callback), credentials);
524 }
525 
OnStartOpeningHandshake(std::unique_ptr<WebSocketHandshakeRequestInfo> request)526 void WebSocketChannel::OnStartOpeningHandshake(
527     std::unique_ptr<WebSocketHandshakeRequestInfo> request) {
528   event_interface_->OnStartOpeningHandshake(std::move(request));
529 }
530 
WriteFrames()531 ChannelState WebSocketChannel::WriteFrames() {
532   int result = OK;
533   do {
534     // This use of base::Unretained is safe because this object owns the
535     // WebSocketStream and destroying it cancels all callbacks.
536     result = stream_->WriteFrames(
537         data_being_sent_->frames(),
538         base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnWriteDone),
539                        base::Unretained(this), false));
540     if (result != ERR_IO_PENDING) {
541       if (OnWriteDone(true, result) == CHANNEL_DELETED)
542         return CHANNEL_DELETED;
543       // OnWriteDone() returns CHANNEL_DELETED on error. Here |state_| is
544       // guaranteed to be the same as before OnWriteDone() call.
545     }
546   } while (result == OK && data_being_sent_);
547   return CHANNEL_ALIVE;
548 }
549 
OnWriteDone(bool synchronous,int result)550 ChannelState WebSocketChannel::OnWriteDone(bool synchronous, int result) {
551   DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
552   DCHECK_NE(CONNECTING, state_);
553   DCHECK_NE(ERR_IO_PENDING, result);
554   DCHECK(data_being_sent_);
555   switch (result) {
556     case OK:
557       if (data_to_send_next_) {
558         data_being_sent_ = std::move(data_to_send_next_);
559         if (!synchronous)
560           return WriteFrames();
561       } else {
562         data_being_sent_.reset();
563         event_interface_->OnSendDataFrameDone();
564       }
565       return CHANNEL_ALIVE;
566 
567     // If a recoverable error condition existed, it would go here.
568 
569     default:
570       DCHECK_LT(result, 0)
571           << "WriteFrames() should only return OK or ERR_ codes";
572 
573       stream_->Close();
574       SetState(CLOSED);
575       DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
576       return CHANNEL_DELETED;
577   }
578 }
579 
ReadFrames()580 ChannelState WebSocketChannel::ReadFrames() {
581   DCHECK(stream_);
582   DCHECK(state_ == CONNECTED || state_ == SEND_CLOSED || state_ == CLOSE_WAIT);
583   DCHECK(read_frames_.empty());
584   if (is_reading_) {
585     return CHANNEL_ALIVE;
586   }
587 
588   if (!InClosingState() && has_received_close_frame_) {
589     DCHECK(!event_interface_->HasPendingDataFrames());
590     // We've been waiting for the client to consume the frames before
591     // responding to the closing handshake initiated by the server.
592     if (RespondToClosingHandshake() == CHANNEL_DELETED) {
593       return CHANNEL_DELETED;
594     }
595   }
596 
597   // TODO(crbug.com/999235): Remove this CHECK.
598   CHECK(event_interface_);
599   while (!event_interface_->HasPendingDataFrames()) {
600     DCHECK(stream_);
601     // This use of base::Unretained is safe because this object owns the
602     // WebSocketStream, and any pending reads will be cancelled when it is
603     // destroyed.
604     const int result = stream_->ReadFrames(
605         &read_frames_,
606         base::BindOnce(base::IgnoreResult(&WebSocketChannel::OnReadDone),
607                        base::Unretained(this), false));
608     if (result == ERR_IO_PENDING) {
609       is_reading_ = true;
610       return CHANNEL_ALIVE;
611     }
612     if (OnReadDone(true, result) == CHANNEL_DELETED) {
613       return CHANNEL_DELETED;
614     }
615     DCHECK_NE(CLOSED, state_);
616     // TODO(crbug.com/999235): Remove this CHECK.
617     CHECK(event_interface_);
618   }
619   return CHANNEL_ALIVE;
620 }
621 
OnReadDone(bool synchronous,int result)622 ChannelState WebSocketChannel::OnReadDone(bool synchronous, int result) {
623   DVLOG(3) << "WebSocketChannel::OnReadDone synchronous?" << synchronous
624            << ", result=" << result
625            << ", read_frames_.size=" << read_frames_.size();
626   DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
627   DCHECK_NE(CONNECTING, state_);
628   DCHECK_NE(ERR_IO_PENDING, result);
629   switch (result) {
630     case OK:
631       // ReadFrames() must use ERR_CONNECTION_CLOSED for a closed connection
632       // with no data read, not an empty response.
633       DCHECK(!read_frames_.empty())
634           << "ReadFrames() returned OK, but nothing was read.";
635       for (auto& read_frame : read_frames_) {
636         if (HandleFrame(std::move(read_frame)) == CHANNEL_DELETED)
637           return CHANNEL_DELETED;
638       }
639       read_frames_.clear();
640       DCHECK_NE(CLOSED, state_);
641       if (!synchronous) {
642         is_reading_ = false;
643         if (!event_interface_->HasPendingDataFrames()) {
644           return ReadFrames();
645         }
646       }
647       return CHANNEL_ALIVE;
648 
649     case ERR_WS_PROTOCOL_ERROR:
650       // This could be kWebSocketErrorProtocolError (specifically, non-minimal
651       // encoding of payload length) or kWebSocketErrorMessageTooBig, or an
652       // extension-specific error.
653       FailChannel("Invalid frame header", kWebSocketErrorProtocolError,
654                   "WebSocket Protocol Error");
655       return CHANNEL_DELETED;
656 
657     default:
658       DCHECK_LT(result, 0)
659           << "ReadFrames() should only return OK or ERR_ codes";
660 
661       stream_->Close();
662       SetState(CLOSED);
663 
664       uint16_t code = kWebSocketErrorAbnormalClosure;
665       std::string reason = "";
666       bool was_clean = false;
667       if (has_received_close_frame_) {
668         code = received_close_code_;
669         reason = received_close_reason_;
670         was_clean = (result == ERR_CONNECTION_CLOSED);
671       }
672 
673       DoDropChannel(was_clean, code, reason);
674       return CHANNEL_DELETED;
675   }
676 }
677 
HandleFrame(std::unique_ptr<WebSocketFrame> frame)678 ChannelState WebSocketChannel::HandleFrame(
679     std::unique_ptr<WebSocketFrame> frame) {
680   if (frame->header.masked) {
681     // RFC6455 Section 5.1 "A client MUST close a connection if it detects a
682     // masked frame."
683     FailChannel(
684         "A server must not mask any frames that it sends to the "
685         "client.",
686         kWebSocketErrorProtocolError, "Masked frame from server");
687     return CHANNEL_DELETED;
688   }
689   const WebSocketFrameHeader::OpCode opcode = frame->header.opcode;
690   DCHECK(!WebSocketFrameHeader::IsKnownControlOpCode(opcode) ||
691          frame->header.final);
692   if (frame->header.reserved1 || frame->header.reserved2 ||
693       frame->header.reserved3) {
694     FailChannel(
695         base::StringPrintf("One or more reserved bits are on: reserved1 = %d, "
696                            "reserved2 = %d, reserved3 = %d",
697                            static_cast<int>(frame->header.reserved1),
698                            static_cast<int>(frame->header.reserved2),
699                            static_cast<int>(frame->header.reserved3)),
700         kWebSocketErrorProtocolError, "Invalid reserved bit");
701     return CHANNEL_DELETED;
702   }
703 
704   // Respond to the frame appropriately to its type.
705   return HandleFrameByState(
706       opcode, frame->header.final,
707       base::make_span(frame->payload, base::checked_cast<size_t>(
708                                           frame->header.payload_length)));
709 }
710 
HandleFrameByState(const WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)711 ChannelState WebSocketChannel::HandleFrameByState(
712     const WebSocketFrameHeader::OpCode opcode,
713     bool final,
714     base::span<const char> payload) {
715   DCHECK_NE(RECV_CLOSED, state_)
716       << "HandleFrame() does not support being called re-entrantly from within "
717          "SendClose()";
718   DCHECK_NE(CLOSED, state_);
719   if (state_ == CLOSE_WAIT) {
720     std::string frame_name;
721     GetFrameTypeForOpcode(opcode, &frame_name);
722 
723     // FailChannel() won't send another Close frame.
724     FailChannel(frame_name + " received after close",
725                 kWebSocketErrorProtocolError, "");
726     return CHANNEL_DELETED;
727   }
728   switch (opcode) {
729     case WebSocketFrameHeader::kOpCodeText:  // fall-thru
730     case WebSocketFrameHeader::kOpCodeBinary:
731     case WebSocketFrameHeader::kOpCodeContinuation:
732       return HandleDataFrame(opcode, final, std::move(payload));
733 
734     case WebSocketFrameHeader::kOpCodePing:
735       DVLOG(1) << "Got Ping of size " << payload.size();
736       if (state_ == CONNECTED) {
737         auto buffer = base::MakeRefCounted<IOBufferWithSize>(payload.size());
738         base::ranges::copy(payload, buffer->data());
739         return SendFrameInternal(true, WebSocketFrameHeader::kOpCodePong,
740                                  std::move(buffer), payload.size());
741       }
742       DVLOG(3) << "Ignored ping in state " << state_;
743       return CHANNEL_ALIVE;
744 
745     case WebSocketFrameHeader::kOpCodePong:
746       DVLOG(1) << "Got Pong of size " << payload.size();
747       // There is no need to do anything with pong messages.
748       return CHANNEL_ALIVE;
749 
750     case WebSocketFrameHeader::kOpCodeClose: {
751       uint16_t code = kWebSocketNormalClosure;
752       std::string reason;
753       std::string message;
754       if (!ParseClose(payload, &code, &reason, &message)) {
755         FailChannel(message, code, reason);
756         return CHANNEL_DELETED;
757       }
758       // TODO(ricea): Find a way to safely log the message from the close
759       // message (escape control codes and so on).
760       return HandleCloseFrame(code, reason);
761     }
762 
763     default:
764       FailChannel(base::StringPrintf("Unrecognized frame opcode: %d", opcode),
765                   kWebSocketErrorProtocolError, "Unknown opcode");
766       return CHANNEL_DELETED;
767   }
768 }
769 
HandleDataFrame(WebSocketFrameHeader::OpCode opcode,bool final,base::span<const char> payload)770 ChannelState WebSocketChannel::HandleDataFrame(
771     WebSocketFrameHeader::OpCode opcode,
772     bool final,
773     base::span<const char> payload) {
774   DVLOG(3) << "WebSocketChannel::HandleDataFrame opcode=" << opcode
775            << ", final?" << final << ", data=" << (void*)payload.data()
776            << ", size=" << payload.size();
777   if (state_ != CONNECTED) {
778     DVLOG(3) << "Ignored data packet received in state " << state_;
779     return CHANNEL_ALIVE;
780   }
781   if (has_received_close_frame_) {
782     DVLOG(3) << "Ignored data packet as we've received a close frame.";
783     return CHANNEL_ALIVE;
784   }
785   DCHECK(opcode == WebSocketFrameHeader::kOpCodeContinuation ||
786          opcode == WebSocketFrameHeader::kOpCodeText ||
787          opcode == WebSocketFrameHeader::kOpCodeBinary);
788   const bool got_continuation =
789       (opcode == WebSocketFrameHeader::kOpCodeContinuation);
790   if (got_continuation != expecting_to_handle_continuation_) {
791     const std::string console_log = got_continuation
792         ? "Received unexpected continuation frame."
793         : "Received start of new message but previous message is unfinished.";
794     const std::string reason = got_continuation
795         ? "Unexpected continuation"
796         : "Previous data frame unfinished";
797     FailChannel(console_log, kWebSocketErrorProtocolError, reason);
798     return CHANNEL_DELETED;
799   }
800   expecting_to_handle_continuation_ = !final;
801   WebSocketFrameHeader::OpCode opcode_to_send = opcode;
802   if (!initial_frame_forwarded_ &&
803       opcode == WebSocketFrameHeader::kOpCodeContinuation) {
804     opcode_to_send = receiving_text_message_
805                          ? WebSocketFrameHeader::kOpCodeText
806                          : WebSocketFrameHeader::kOpCodeBinary;
807   }
808   if (opcode == WebSocketFrameHeader::kOpCodeText ||
809       (opcode == WebSocketFrameHeader::kOpCodeContinuation &&
810        receiving_text_message_)) {
811     // This call is not redundant when size == 0 because it tells us what
812     // the current state is.
813     StreamingUtf8Validator::State state =
814         incoming_utf8_validator_.AddBytes(base::as_byte_span(payload));
815     if (state == StreamingUtf8Validator::INVALID ||
816         (state == StreamingUtf8Validator::VALID_MIDPOINT && final)) {
817       FailChannel("Could not decode a text frame as UTF-8.",
818                   kWebSocketErrorProtocolError, "Invalid UTF-8 in text frame");
819       return CHANNEL_DELETED;
820     }
821     receiving_text_message_ = !final;
822     DCHECK(!final || state == StreamingUtf8Validator::VALID_ENDPOINT);
823   }
824   if (payload.size() == 0U && !final)
825     return CHANNEL_ALIVE;
826 
827   initial_frame_forwarded_ = !final;
828   // Sends the received frame to the renderer process.
829   event_interface_->OnDataFrame(final, opcode_to_send, payload);
830   return CHANNEL_ALIVE;
831 }
832 
HandleCloseFrame(uint16_t code,const std::string & reason)833 ChannelState WebSocketChannel::HandleCloseFrame(uint16_t code,
834                                                 const std::string& reason) {
835   DVLOG(1) << "Got Close with code " << code;
836   switch (state_) {
837     case CONNECTED:
838       has_received_close_frame_ = true;
839       received_close_code_ = code;
840       received_close_reason_ = reason;
841       if (event_interface_->HasPendingDataFrames()) {
842         // We have some data to be sent to the renderer before sending this
843         // frame.
844         return CHANNEL_ALIVE;
845       }
846       return RespondToClosingHandshake();
847 
848     case SEND_CLOSED:
849       SetState(CLOSE_WAIT);
850       DCHECK(close_timer_.IsRunning());
851       close_timer_.Stop();
852       // This use of base::Unretained() is safe because we stop the timer
853       // in the destructor.
854       close_timer_.Start(FROM_HERE, underlying_connection_close_timeout_,
855                          base::BindOnce(&WebSocketChannel::CloseTimeout,
856                                         base::Unretained(this)));
857 
858       // From RFC6455 section 7.1.5: "Each endpoint
859       // will see the status code sent by the other end as _The WebSocket
860       // Connection Close Code_."
861       has_received_close_frame_ = true;
862       received_close_code_ = code;
863       received_close_reason_ = reason;
864       break;
865 
866     default:
867       LOG(DFATAL) << "Got Close in unexpected state " << state_;
868       break;
869   }
870   return CHANNEL_ALIVE;
871 }
872 
RespondToClosingHandshake()873 ChannelState WebSocketChannel::RespondToClosingHandshake() {
874   DCHECK(has_received_close_frame_);
875   DCHECK_EQ(CONNECTED, state_);
876   SetState(RECV_CLOSED);
877   if (SendClose(received_close_code_, received_close_reason_) ==
878       CHANNEL_DELETED)
879     return CHANNEL_DELETED;
880   DCHECK_EQ(RECV_CLOSED, state_);
881 
882   SetState(CLOSE_WAIT);
883   DCHECK(!close_timer_.IsRunning());
884   // This use of base::Unretained() is safe because we stop the timer
885   // in the destructor.
886   close_timer_.Start(
887       FROM_HERE, underlying_connection_close_timeout_,
888       base::BindOnce(&WebSocketChannel::CloseTimeout, base::Unretained(this)));
889 
890   event_interface_->OnClosingHandshake();
891   return CHANNEL_ALIVE;
892 }
893 
SendFrameInternal(bool fin,WebSocketFrameHeader::OpCode op_code,scoped_refptr<IOBuffer> buffer,uint64_t buffer_size)894 ChannelState WebSocketChannel::SendFrameInternal(
895     bool fin,
896     WebSocketFrameHeader::OpCode op_code,
897     scoped_refptr<IOBuffer> buffer,
898     uint64_t buffer_size) {
899   DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
900   DCHECK(stream_);
901 
902   auto frame = std::make_unique<WebSocketFrame>(op_code);
903   WebSocketFrameHeader& header = frame->header;
904   header.final = fin;
905   header.masked = true;
906   header.payload_length = buffer_size;
907   frame->payload = buffer->data();
908 
909   if (data_being_sent_) {
910     // Either the link to the WebSocket server is saturated, or several messages
911     // are being sent in a batch.
912     if (!data_to_send_next_)
913       data_to_send_next_ = std::make_unique<SendBuffer>();
914     data_to_send_next_->AddFrame(std::move(frame), std::move(buffer));
915     return CHANNEL_ALIVE;
916   }
917 
918   data_being_sent_ = std::make_unique<SendBuffer>();
919   data_being_sent_->AddFrame(std::move(frame), std::move(buffer));
920   return WriteFrames();
921 }
922 
FailChannel(const std::string & message,uint16_t code,const std::string & reason)923 void WebSocketChannel::FailChannel(const std::string& message,
924                                    uint16_t code,
925                                    const std::string& reason) {
926   DCHECK_NE(FRESHLY_CONSTRUCTED, state_);
927   DCHECK_NE(CONNECTING, state_);
928   DCHECK_NE(CLOSED, state_);
929 
930   stream_->GetNetLogWithSource().AddEvent(
931       net::NetLogEventType::WEBSOCKET_INVALID_FRAME,
932       [&] { return NetLogFailParam(code, reason, message); });
933 
934   if (state_ == CONNECTED) {
935     if (SendClose(code, reason) == CHANNEL_DELETED)
936       return;
937   }
938 
939   // Careful study of RFC6455 section 7.1.7 and 7.1.1 indicates the browser
940   // should close the connection itself without waiting for the closing
941   // handshake.
942   stream_->Close();
943   SetState(CLOSED);
944   event_interface_->OnFailChannel(message, ERR_FAILED, std::nullopt);
945 }
946 
SendClose(uint16_t code,const std::string & reason)947 ChannelState WebSocketChannel::SendClose(uint16_t code,
948                                          const std::string& reason) {
949   DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED);
950   DCHECK_LE(reason.size(), kMaximumCloseReasonLength);
951   scoped_refptr<IOBuffer> body;
952   uint64_t size = 0;
953   if (code == kWebSocketErrorNoStatusReceived) {
954     // Special case: translate kWebSocketErrorNoStatusReceived into a Close
955     // frame with no payload.
956     DCHECK(reason.empty());
957     body = base::MakeRefCounted<IOBufferWithSize>();
958   } else {
959     const size_t payload_length = kWebSocketCloseCodeLength + reason.length();
960     body = base::MakeRefCounted<IOBufferWithSize>(payload_length);
961     size = payload_length;
962     auto [code_span, body_span] =
963         body->span().split_at<kWebSocketCloseCodeLength>();
964     base::as_writable_bytes(code_span).copy_from(base::U16ToBigEndian(code));
965     static_assert(sizeof(code) == kWebSocketCloseCodeLength,
966                   "they should both be two");
967     body_span.copy_from(reason);
968   }
969 
970   return SendFrameInternal(true, WebSocketFrameHeader::kOpCodeClose,
971                            std::move(body), size);
972 }
973 
ParseClose(base::span<const char> payload,uint16_t * code,std::string * reason,std::string * message)974 bool WebSocketChannel::ParseClose(base::span<const char> payload,
975                                   uint16_t* code,
976                                   std::string* reason,
977                                   std::string* message) {
978   const uint64_t size = static_cast<uint64_t>(payload.size());
979   reason->clear();
980   if (size < kWebSocketCloseCodeLength) {
981     if (size == 0U) {
982       *code = kWebSocketErrorNoStatusReceived;
983       return true;
984     }
985 
986     DVLOG(1) << "Close frame with payload size " << size << " received "
987              << "(the first byte is " << std::hex
988              << static_cast<int>(payload[0]) << ")";
989     *code = kWebSocketErrorProtocolError;
990     *message =
991         "Received a broken close frame containing an invalid size body.";
992     return false;
993   }
994 
995   const char* data = payload.data();
996   uint16_t unchecked_code =
997       base::U16FromBigEndian(base::as_byte_span(payload).first<2>());
998   static_assert(sizeof(unchecked_code) == kWebSocketCloseCodeLength,
999                 "they should both be two bytes");
1000 
1001   switch (unchecked_code) {
1002     case kWebSocketErrorNoStatusReceived:
1003     case kWebSocketErrorAbnormalClosure:
1004     case kWebSocketErrorTlsHandshake:
1005       *code = kWebSocketErrorProtocolError;
1006       *message =
1007           "Received a broken close frame containing a reserved status code.";
1008       return false;
1009 
1010     default:
1011       *code = unchecked_code;
1012       break;
1013   }
1014 
1015   std::string text(data + kWebSocketCloseCodeLength, data + size);
1016   if (StreamingUtf8Validator::Validate(text)) {
1017     reason->swap(text);
1018     return true;
1019   }
1020 
1021   *code = kWebSocketErrorProtocolError;
1022   *reason = "Invalid UTF-8 in Close frame";
1023   *message = "Received a broken close frame containing invalid UTF-8.";
1024   return false;
1025 }
1026 
DoDropChannel(bool was_clean,uint16_t code,const std::string & reason)1027 void WebSocketChannel::DoDropChannel(bool was_clean,
1028                                      uint16_t code,
1029                                      const std::string& reason) {
1030   event_interface_->OnDropChannel(was_clean, code, reason);
1031 }
1032 
CloseTimeout()1033 void WebSocketChannel::CloseTimeout() {
1034   stream_->GetNetLogWithSource().AddEvent(
1035       net::NetLogEventType::WEBSOCKET_CLOSE_TIMEOUT);
1036   stream_->Close();
1037   SetState(CLOSED);
1038   if (has_received_close_frame_) {
1039     DoDropChannel(true, received_close_code_, received_close_reason_);
1040   } else {
1041     DoDropChannel(false, kWebSocketErrorAbnormalClosure, "");
1042   }
1043   // |this| has been deleted.
1044 }
1045 
1046 }  // namespace net
1047