1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/websockets/websocket_handshake_stream_base.h"
6
7 #include <stddef.h>
8
9 #include <unordered_set>
10
11 #include "base/metrics/histogram_macros.h"
12 #include "base/strings/strcat.h"
13 #include "base/strings/string_util.h"
14 #include "net/http/http_request_headers.h"
15 #include "net/http/http_response_headers.h"
16 #include "net/websockets/websocket_extension.h"
17 #include "net/websockets/websocket_extension_parser.h"
18 #include "net/websockets/websocket_handshake_constants.h"
19
20 namespace net {
21
22 // static
MultipleHeaderValuesMessage(const std::string & header_name)23 std::string WebSocketHandshakeStreamBase::MultipleHeaderValuesMessage(
24 const std::string& header_name) {
25 return base::StrCat(
26 {"'", header_name,
27 "' header must not appear more than once in a response"});
28 }
29
30 // static
AddVectorHeaderIfNonEmpty(const char * name,const std::vector<std::string> & value,HttpRequestHeaders * headers)31 void WebSocketHandshakeStreamBase::AddVectorHeaderIfNonEmpty(
32 const char* name,
33 const std::vector<std::string>& value,
34 HttpRequestHeaders* headers) {
35 if (value.empty())
36 return;
37 headers->SetHeader(name, base::JoinString(value, ", "));
38 }
39
40 // static
ValidateSubProtocol(const HttpResponseHeaders * headers,const std::vector<std::string> & requested_sub_protocols,std::string * sub_protocol,std::string * failure_message)41 bool WebSocketHandshakeStreamBase::ValidateSubProtocol(
42 const HttpResponseHeaders* headers,
43 const std::vector<std::string>& requested_sub_protocols,
44 std::string* sub_protocol,
45 std::string* failure_message) {
46 size_t iter = 0;
47 std::string value;
48 std::unordered_set<std::string> requested_set(requested_sub_protocols.begin(),
49 requested_sub_protocols.end());
50 int count = 0;
51 bool has_multiple_protocols = false;
52 bool has_invalid_protocol = false;
53
54 while (!has_invalid_protocol || !has_multiple_protocols) {
55 std::string temp_value;
56 if (!headers->EnumerateHeader(&iter, websockets::kSecWebSocketProtocol,
57 &temp_value))
58 break;
59 value = temp_value;
60 if (requested_set.count(value) == 0)
61 has_invalid_protocol = true;
62 if (++count > 1)
63 has_multiple_protocols = true;
64 }
65
66 if (has_multiple_protocols) {
67 *failure_message =
68 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
69 return false;
70 } else if (count > 0 && requested_sub_protocols.size() == 0) {
71 *failure_message =
72 base::StrCat({"Response must not include 'Sec-WebSocket-Protocol' "
73 "header if not present in request: ",
74 value});
75 return false;
76 } else if (has_invalid_protocol) {
77 *failure_message = "'Sec-WebSocket-Protocol' header value '" + value +
78 "' in response does not match any of sent values";
79 return false;
80 } else if (requested_sub_protocols.size() > 0 && count == 0) {
81 *failure_message =
82 "Sent non-empty 'Sec-WebSocket-Protocol' header "
83 "but no response was received";
84 return false;
85 }
86 *sub_protocol = value;
87 return true;
88 }
89
90 // static
ValidateExtensions(const HttpResponseHeaders * headers,std::string * accepted_extensions_descriptor,std::string * failure_message,WebSocketExtensionParams * params)91 bool WebSocketHandshakeStreamBase::ValidateExtensions(
92 const HttpResponseHeaders* headers,
93 std::string* accepted_extensions_descriptor,
94 std::string* failure_message,
95 WebSocketExtensionParams* params) {
96 size_t iter = 0;
97 std::string header_value;
98 std::vector<std::string> header_values;
99 // TODO(ricea): If adding support for additional extensions, generalise this
100 // code.
101 bool seen_permessage_deflate = false;
102 while (headers->EnumerateHeader(&iter, websockets::kSecWebSocketExtensions,
103 &header_value)) {
104 WebSocketExtensionParser parser;
105 if (!parser.Parse(header_value)) {
106 // TODO(yhirano) Set appropriate failure message.
107 *failure_message =
108 "'Sec-WebSocket-Extensions' header value is "
109 "rejected by the parser: " +
110 header_value;
111 return false;
112 }
113
114 const std::vector<WebSocketExtension>& extensions = parser.extensions();
115 for (const auto& extension : extensions) {
116 if (extension.name() == "permessage-deflate") {
117 if (seen_permessage_deflate) {
118 *failure_message = "Received duplicate permessage-deflate response";
119 return false;
120 }
121 seen_permessage_deflate = true;
122 auto& deflate_parameters = params->deflate_parameters;
123 if (!deflate_parameters.Initialize(extension, failure_message) ||
124 !deflate_parameters.IsValidAsResponse(failure_message)) {
125 *failure_message = "Error in permessage-deflate: " + *failure_message;
126 return false;
127 }
128 // Note that we don't have to check the request-response compatibility
129 // here because we send a request compatible with any valid responses.
130 // TODO(yhirano): Place a DCHECK here.
131
132 header_values.push_back(header_value);
133 } else {
134 *failure_message = "Found an unsupported extension '" +
135 extension.name() +
136 "' in 'Sec-WebSocket-Extensions' header";
137 return false;
138 }
139 }
140 }
141 *accepted_extensions_descriptor = base::JoinString(header_values, ", ");
142 params->deflate_enabled = seen_permessage_deflate;
143 return true;
144 }
145
RecordHandshakeResult(HandshakeResult result)146 void WebSocketHandshakeStreamBase::RecordHandshakeResult(
147 HandshakeResult result) {
148 UMA_HISTOGRAM_ENUMERATION("Net.WebSocket.HandshakeResult2", result,
149 HandshakeResult::NUM_HANDSHAKE_RESULT_TYPES);
150 }
151
152 } // namespace net
153