xref: /aosp_15_r20/external/cronet/net/websockets/websocket_handshake_stream_base.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_handshake_stream_base.h"
6 
7 #include <stddef.h>
8 
9 #include <unordered_set>
10 
11 #include "base/metrics/histogram_macros.h"
12 #include "base/strings/strcat.h"
13 #include "base/strings/string_util.h"
14 #include "net/http/http_request_headers.h"
15 #include "net/http/http_response_headers.h"
16 #include "net/websockets/websocket_extension.h"
17 #include "net/websockets/websocket_extension_parser.h"
18 #include "net/websockets/websocket_handshake_constants.h"
19 
20 namespace net {
21 
22 // static
MultipleHeaderValuesMessage(const std::string & header_name)23 std::string WebSocketHandshakeStreamBase::MultipleHeaderValuesMessage(
24     const std::string& header_name) {
25   return base::StrCat(
26       {"'", header_name,
27        "' header must not appear more than once in a response"});
28 }
29 
30 // static
AddVectorHeaderIfNonEmpty(const char * name,const std::vector<std::string> & value,HttpRequestHeaders * headers)31 void WebSocketHandshakeStreamBase::AddVectorHeaderIfNonEmpty(
32     const char* name,
33     const std::vector<std::string>& value,
34     HttpRequestHeaders* headers) {
35   if (value.empty())
36     return;
37   headers->SetHeader(name, base::JoinString(value, ", "));
38 }
39 
40 // static
ValidateSubProtocol(const HttpResponseHeaders * headers,const std::vector<std::string> & requested_sub_protocols,std::string * sub_protocol,std::string * failure_message)41 bool WebSocketHandshakeStreamBase::ValidateSubProtocol(
42     const HttpResponseHeaders* headers,
43     const std::vector<std::string>& requested_sub_protocols,
44     std::string* sub_protocol,
45     std::string* failure_message) {
46   size_t iter = 0;
47   std::string value;
48   std::unordered_set<std::string> requested_set(requested_sub_protocols.begin(),
49                                                 requested_sub_protocols.end());
50   int count = 0;
51   bool has_multiple_protocols = false;
52   bool has_invalid_protocol = false;
53 
54   while (!has_invalid_protocol || !has_multiple_protocols) {
55     std::string temp_value;
56     if (!headers->EnumerateHeader(&iter, websockets::kSecWebSocketProtocol,
57                                   &temp_value))
58       break;
59     value = temp_value;
60     if (requested_set.count(value) == 0)
61       has_invalid_protocol = true;
62     if (++count > 1)
63       has_multiple_protocols = true;
64   }
65 
66   if (has_multiple_protocols) {
67     *failure_message =
68         MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
69     return false;
70   } else if (count > 0 && requested_sub_protocols.size() == 0) {
71     *failure_message =
72         base::StrCat({"Response must not include 'Sec-WebSocket-Protocol' "
73                       "header if not present in request: ",
74                       value});
75     return false;
76   } else if (has_invalid_protocol) {
77     *failure_message = "'Sec-WebSocket-Protocol' header value '" + value +
78                        "' in response does not match any of sent values";
79     return false;
80   } else if (requested_sub_protocols.size() > 0 && count == 0) {
81     *failure_message =
82         "Sent non-empty 'Sec-WebSocket-Protocol' header "
83         "but no response was received";
84     return false;
85   }
86   *sub_protocol = value;
87   return true;
88 }
89 
90 // static
ValidateExtensions(const HttpResponseHeaders * headers,std::string * accepted_extensions_descriptor,std::string * failure_message,WebSocketExtensionParams * params)91 bool WebSocketHandshakeStreamBase::ValidateExtensions(
92     const HttpResponseHeaders* headers,
93     std::string* accepted_extensions_descriptor,
94     std::string* failure_message,
95     WebSocketExtensionParams* params) {
96   size_t iter = 0;
97   std::string header_value;
98   std::vector<std::string> header_values;
99   // TODO(ricea): If adding support for additional extensions, generalise this
100   // code.
101   bool seen_permessage_deflate = false;
102   while (headers->EnumerateHeader(&iter, websockets::kSecWebSocketExtensions,
103                                   &header_value)) {
104     WebSocketExtensionParser parser;
105     if (!parser.Parse(header_value)) {
106       // TODO(yhirano) Set appropriate failure message.
107       *failure_message =
108           "'Sec-WebSocket-Extensions' header value is "
109           "rejected by the parser: " +
110           header_value;
111       return false;
112     }
113 
114     const std::vector<WebSocketExtension>& extensions = parser.extensions();
115     for (const auto& extension : extensions) {
116       if (extension.name() == "permessage-deflate") {
117         if (seen_permessage_deflate) {
118           *failure_message = "Received duplicate permessage-deflate response";
119           return false;
120         }
121         seen_permessage_deflate = true;
122         auto& deflate_parameters = params->deflate_parameters;
123         if (!deflate_parameters.Initialize(extension, failure_message) ||
124             !deflate_parameters.IsValidAsResponse(failure_message)) {
125           *failure_message = "Error in permessage-deflate: " + *failure_message;
126           return false;
127         }
128         // Note that we don't have to check the request-response compatibility
129         // here because we send a request compatible with any valid responses.
130         // TODO(yhirano): Place a DCHECK here.
131 
132         header_values.push_back(header_value);
133       } else {
134         *failure_message = "Found an unsupported extension '" +
135                            extension.name() +
136                            "' in 'Sec-WebSocket-Extensions' header";
137         return false;
138       }
139     }
140   }
141   *accepted_extensions_descriptor = base::JoinString(header_values, ", ");
142   params->deflate_enabled = seen_permessage_deflate;
143   return true;
144 }
145 
RecordHandshakeResult(HandshakeResult result)146 void WebSocketHandshakeStreamBase::RecordHandshakeResult(
147     HandshakeResult result) {
148   UMA_HISTOGRAM_ENUMERATION("Net.WebSocket.HandshakeResult2", result,
149                             HandshakeResult::NUM_HANDSHAKE_RESULT_TYPES);
150 }
151 
152 }  // namespace net
153