1 // Copyright 2022 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <memory>
6 #include <string>
7
8 #include "net/socket/transport_client_socket_test_util.h"
9
10 #include "base/memory/ref_counted.h"
11 #include "net/base/io_buffer.h"
12 #include "net/base/net_errors.h"
13 #include "net/test/gtest_util.h"
14 #include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
15 #include "testing/gtest/include/gtest/gtest.h"
16
17 namespace net {
18
SendRequestAndResponse(StreamSocket * socket,StreamSocket * connected_socket)19 void SendRequestAndResponse(StreamSocket* socket,
20 StreamSocket* connected_socket) {
21 // Send client request.
22 const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
23 int request_len = strlen(request_text);
24 scoped_refptr<DrainableIOBuffer> request_buffer =
25 base::MakeRefCounted<DrainableIOBuffer>(
26 base::MakeRefCounted<IOBufferWithSize>(request_len), request_len);
27 memcpy(request_buffer->data(), request_text, request_len);
28
29 int bytes_written = 0;
30 while (request_buffer->BytesRemaining() > 0) {
31 TestCompletionCallback write_callback;
32 int write_result =
33 socket->Write(request_buffer.get(), request_buffer->BytesRemaining(),
34 write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
35
36 write_result = write_callback.GetResult(write_result);
37 ASSERT_GT(write_result, 0);
38 ASSERT_LE(bytes_written + write_result, request_len);
39 request_buffer->DidConsume(write_result);
40
41 bytes_written += write_result;
42 }
43 ASSERT_EQ(request_len, bytes_written);
44
45 // Confirm that the server receives what client sent.
46 std::string data_received =
47 ReadDataOfExpectedLength(connected_socket, bytes_written);
48 ASSERT_TRUE(connected_socket->IsConnectedAndIdle());
49 ASSERT_EQ(request_text, data_received);
50
51 // Write server response.
52 SendServerResponse(connected_socket);
53 }
54
ReadDataOfExpectedLength(StreamSocket * socket,int expected_bytes_read)55 std::string ReadDataOfExpectedLength(StreamSocket* socket,
56 int expected_bytes_read) {
57 int bytes_read = 0;
58 scoped_refptr<IOBufferWithSize> read_buffer =
59 base::MakeRefCounted<IOBufferWithSize>(expected_bytes_read);
60 while (bytes_read < expected_bytes_read) {
61 TestCompletionCallback read_callback;
62 int rv = socket->Read(read_buffer.get(), expected_bytes_read - bytes_read,
63 read_callback.callback());
64 EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
65 rv = read_callback.GetResult(rv);
66 EXPECT_GE(rv, 0);
67 bytes_read += rv;
68 }
69 EXPECT_EQ(expected_bytes_read, bytes_read);
70 return std::string(read_buffer->data(), bytes_read);
71 }
72
SendServerResponse(StreamSocket * socket)73 void SendServerResponse(StreamSocket* socket) {
74 const char kServerReply[] = "HTTP/1.1 404 Not Found";
75 int reply_len = strlen(kServerReply);
76 scoped_refptr<DrainableIOBuffer> write_buffer =
77 base::MakeRefCounted<DrainableIOBuffer>(
78 base::MakeRefCounted<IOBufferWithSize>(reply_len), reply_len);
79 memcpy(write_buffer->data(), kServerReply, reply_len);
80 int bytes_written = 0;
81 while (write_buffer->BytesRemaining() > 0) {
82 TestCompletionCallback write_callback;
83 int write_result =
84 socket->Write(write_buffer.get(), write_buffer->BytesRemaining(),
85 write_callback.callback(), TRAFFIC_ANNOTATION_FOR_TESTS);
86 write_result = write_callback.GetResult(write_result);
87 ASSERT_GE(write_result, 0);
88 ASSERT_LE(bytes_written + write_result, reply_len);
89 write_buffer->DidConsume(write_result);
90 bytes_written += write_result;
91 }
92 }
93
DrainStreamSocket(StreamSocket * socket,IOBuffer * buf,uint32_t buf_len,uint32_t bytes_to_read,TestCompletionCallback * callback)94 int DrainStreamSocket(StreamSocket* socket,
95 IOBuffer* buf,
96 uint32_t buf_len,
97 uint32_t bytes_to_read,
98 TestCompletionCallback* callback) {
99 int rv = OK;
100 uint32_t bytes_read = 0;
101
102 while (bytes_read < bytes_to_read) {
103 rv = socket->Read(buf, buf_len, callback->callback());
104 EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
105 rv = callback->GetResult(rv);
106 EXPECT_GT(rv, 0);
107 bytes_read += rv;
108 }
109
110 return static_cast<int>(bytes_read);
111 }
112
113 } // namespace net
114