xref: /aosp_15_r20/external/cronet/net/websockets/websocket_basic_handshake_stream_test.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2018 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/websockets/websocket_basic_handshake_stream.h"
6 
7 #include <set>
8 #include <string>
9 #include <utility>
10 #include <vector>
11 
12 #include "base/containers/span.h"
13 #include "net/base/address_list.h"
14 #include "net/base/ip_address.h"
15 #include "net/base/ip_endpoint.h"
16 #include "net/base/net_errors.h"
17 #include "net/base/test_completion_callback.h"
18 #include "net/http/http_request_info.h"
19 #include "net/http/http_response_info.h"
20 #include "net/log/net_log_with_source.h"
21 #include "net/socket/client_socket_handle.h"
22 #include "net/socket/socket_test_util.h"
23 #include "net/socket/stream_socket.h"
24 #include "net/socket/websocket_endpoint_lock_manager.h"
25 #include "net/traffic_annotation/network_traffic_annotation.h"
26 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
27 #include "net/websockets/websocket_test_util.h"
28 #include "testing/gmock/include/gmock/gmock.h"
29 #include "testing/gtest/include/gtest/gtest.h"
30 #include "url/gurl.h"
31 #include "url/origin.h"
32 
33 namespace net {
34 namespace {
35 
TEST(WebSocketBasicHandshakeStreamTest,ConnectionClosedOnFailure)36 TEST(WebSocketBasicHandshakeStreamTest, ConnectionClosedOnFailure) {
37   std::string request = WebSocketStandardRequest(
38       "/", "www.example.org",
39       url::Origin::Create(GURL("http://origin.example.org")),
40       /*send_additional_request_headers=*/{}, /*extra_headers=*/{});
41   std::string response =
42       "HTTP/1.1 404 Not Found\r\n"
43       "Content-Length: 0\r\n"
44       "\r\n";
45   MockWrite writes[] = {MockWrite(SYNCHRONOUS, 0, request.c_str())};
46   MockRead reads[] = {MockRead(SYNCHRONOUS, 1, response.c_str()),
47                       MockRead(SYNCHRONOUS, ERR_IO_PENDING, 2)};
48   IPEndPoint end_point(IPAddress(127, 0, 0, 1), 80);
49   SequencedSocketData sequenced_socket_data(
50       MockConnect(SYNCHRONOUS, OK, end_point), reads, writes);
51   auto socket = std::make_unique<MockTCPClientSocket>(
52       AddressList(end_point), nullptr, &sequenced_socket_data);
53   const int connect_result = socket->Connect(CompletionOnceCallback());
54   EXPECT_EQ(connect_result, OK);
55   const MockTCPClientSocket* const socket_ptr = socket.get();
56   auto handle = std::make_unique<ClientSocketHandle>();
57   handle->SetSocket(std::move(socket));
58   DummyConnectDelegate delegate;
59   WebSocketEndpointLockManager endpoint_lock_manager;
60   TestWebSocketStreamRequestAPI stream_request_api;
61   std::vector<std::string> extensions = {
62       "permessage-deflate; client_max_window_bits"};
63   WebSocketBasicHandshakeStream basic_handshake_stream(
64       std::move(handle), &delegate, false, {}, extensions, &stream_request_api,
65       &endpoint_lock_manager);
66   basic_handshake_stream.SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
67   HttpRequestInfo request_info;
68   request_info.url = GURL("ws://www.example.com/");
69   request_info.method = "GET";
70   request_info.traffic_annotation =
71       MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
72   TestCompletionCallback callback1;
73   NetLogWithSource net_log;
74   basic_handshake_stream.RegisterRequest(&request_info);
75   const int result1 =
76       callback1.GetResult(basic_handshake_stream.InitializeStream(
77           true, LOWEST, net_log, callback1.callback()));
78   EXPECT_EQ(result1, OK);
79 
80   auto request_headers = WebSocketCommonTestHeaders();
81   HttpResponseInfo response_info;
82   TestCompletionCallback callback2;
83   const int result2 = callback2.GetResult(basic_handshake_stream.SendRequest(
84       request_headers, &response_info, callback2.callback()));
85   EXPECT_EQ(result2, OK);
86 
87   TestCompletionCallback callback3;
88   const int result3 = callback3.GetResult(
89       basic_handshake_stream.ReadResponseHeaders(callback2.callback()));
90   EXPECT_EQ(result3, ERR_INVALID_RESPONSE);
91 
92   EXPECT_FALSE(socket_ptr->IsConnected());
93 }
94 
TEST(WebSocketBasicHandshakeStreamTest,DnsAliasesCanBeAccessed)95 TEST(WebSocketBasicHandshakeStreamTest, DnsAliasesCanBeAccessed) {
96   std::string request = WebSocketStandardRequest(
97       "/", "www.example.org",
98       url::Origin::Create(GURL("http://origin.example.org")),
99       /*send_additional_request_headers=*/{}, /*extra_headers=*/{});
100   std::string response = WebSocketStandardResponse("");
101   MockWrite writes[] = {MockWrite(SYNCHRONOUS, 0, request.c_str())};
102   MockRead reads[] = {MockRead(SYNCHRONOUS, 1, response.c_str()),
103                       MockRead(SYNCHRONOUS, ERR_IO_PENDING, 2)};
104 
105   IPEndPoint end_point(IPAddress(127, 0, 0, 1), 80);
106   SequencedSocketData sequenced_socket_data(
107       MockConnect(SYNCHRONOUS, OK, end_point), reads, writes);
108   auto socket = std::make_unique<MockTCPClientSocket>(
109       AddressList(end_point), nullptr, &sequenced_socket_data);
110   const int connect_result = socket->Connect(CompletionOnceCallback());
111   EXPECT_EQ(connect_result, OK);
112 
113   std::set<std::string> aliases({"alias1", "alias2", "www.example.org"});
114   socket->SetDnsAliases(aliases);
115   EXPECT_THAT(
116       socket->GetDnsAliases(),
117       testing::UnorderedElementsAre("alias1", "alias2", "www.example.org"));
118 
119   const MockTCPClientSocket* const socket_ptr = socket.get();
120   auto handle = std::make_unique<ClientSocketHandle>();
121   handle->SetSocket(std::move(socket));
122   EXPECT_THAT(
123       handle->socket()->GetDnsAliases(),
124       testing::UnorderedElementsAre("alias1", "alias2", "www.example.org"));
125 
126   DummyConnectDelegate delegate;
127   WebSocketEndpointLockManager endpoint_lock_manager;
128   TestWebSocketStreamRequestAPI stream_request_api;
129   std::vector<std::string> extensions = {
130       "permessage-deflate; client_max_window_bits"};
131   WebSocketBasicHandshakeStream basic_handshake_stream(
132       std::move(handle), &delegate, false, {}, extensions, &stream_request_api,
133       &endpoint_lock_manager);
134   basic_handshake_stream.SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
135   HttpRequestInfo request_info;
136   request_info.url = GURL("ws://www.example.com/");
137   request_info.method = "GET";
138   request_info.traffic_annotation =
139       MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS);
140   TestCompletionCallback callback1;
141   NetLogWithSource net_log;
142   basic_handshake_stream.RegisterRequest(&request_info);
143   const int result1 =
144       callback1.GetResult(basic_handshake_stream.InitializeStream(
145           true, LOWEST, net_log, callback1.callback()));
146   EXPECT_EQ(result1, OK);
147 
148   auto request_headers = WebSocketCommonTestHeaders();
149   HttpResponseInfo response_info;
150   TestCompletionCallback callback2;
151   const int result2 = callback2.GetResult(basic_handshake_stream.SendRequest(
152       request_headers, &response_info, callback2.callback()));
153   EXPECT_EQ(result2, OK);
154 
155   TestCompletionCallback callback3;
156   const int result3 = callback3.GetResult(
157       basic_handshake_stream.ReadResponseHeaders(callback2.callback()));
158   EXPECT_EQ(result3, OK);
159 
160   EXPECT_TRUE(socket_ptr->IsConnected());
161 
162   EXPECT_THAT(basic_handshake_stream.GetDnsAliases(),
163               testing::ElementsAre("alias1", "alias2", "www.example.org"));
164 }
165 
166 }  // namespace
167 }  // namespace net
168