1 // Copyright 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "test_utils.h"  // NOLINT(build/include)

#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <cerrno>
#include <charconv>
#include <memory>
#include <string>
#include <system_error>
#include <thread>  // NOLINT(build/c++11)

#include <absl/strings/match.h>
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "sandboxed_api/util/status_macros.h"
30
31 namespace curl::tests {
32
33 int CurlTestUtils::port_;
34
35 std::thread CurlTestUtils::server_thread_;
36
CurlTestSetUp()37 absl::Status CurlTestUtils::CurlTestSetUp() {
38 // Initialize sandbox2 and SAPI
39 sandbox_ = std::make_unique<curl::CurlSapiSandbox>();
40 SAPI_RETURN_IF_ERROR(sandbox_->Init());
41 api_ = std::make_unique<curl::CurlApi>(sandbox_.get());
42
43 // Initialize curl
44 SAPI_ASSIGN_OR_RETURN(curl::CURL * curl_handle, api_->curl_easy_init());
45 if (!curl_handle) {
46 return absl::UnavailableError("curl_easy_init returned nullptr");
47 }
48 curl_ = std::make_unique<sapi::v::RemotePtr>(curl_handle);
49
50 int curl_code = 0;
51
52 // Specify request URL
53 sapi::v::ConstCStr sapi_url(kUrl);
54 SAPI_ASSIGN_OR_RETURN(
55 curl_code, api_->curl_easy_setopt_ptr(curl_.get(), curl::CURLOPT_URL,
56 sapi_url.PtrBefore()));
57 if (curl_code != curl::CURLE_OK) {
58 return absl::UnavailableError(absl::StrCat(
59 "curl_easy_setopt_ptr returned with the error code ", curl_code));
60 }
61
62 // Set port
63 SAPI_ASSIGN_OR_RETURN(curl_code, api_->curl_easy_setopt_long(
64 curl_.get(), curl::CURLOPT_PORT, port_));
65 if (curl_code != curl::CURLE_OK) {
66 return absl::UnavailableError(absl::StrCat(
67 "curl_easy_setopt_long returned with the error code ", curl_code));
68 }
69
70 // Generate pointer to the WriteToMemory callback
71 void* function_ptr;
72 SAPI_RETURN_IF_ERROR(
73 sandbox_->rpc_channel()->Symbol("WriteToMemory", &function_ptr));
74 sapi::v::RemotePtr remote_function_ptr(function_ptr);
75
76 // Set WriteToMemory as the write function
77 SAPI_ASSIGN_OR_RETURN(curl_code, api_->curl_easy_setopt_ptr(
78 curl_.get(), curl::CURLOPT_WRITEFUNCTION,
79 &remote_function_ptr));
80 if (curl_code != curl::CURLE_OK) {
81 return absl::UnavailableError(absl::StrCat(
82 "curl_easy_setopt_ptr returned with the error code ", curl_code));
83 }
84
85 // Pass memory chunk object to the callback
86 chunk_ = std::make_unique<sapi::v::LenVal>(0);
87 SAPI_ASSIGN_OR_RETURN(
88 curl_code, api_->curl_easy_setopt_ptr(
89 curl_.get(), curl::CURLOPT_WRITEDATA, chunk_->PtrBoth()));
90 if (curl_code != curl::CURLE_OK) {
91 return absl::UnavailableError(absl::StrCat(
92 "curl_easy_setopt_ptr returned with the error code ", curl_code));
93 }
94
95 return absl::OkStatus();
96 }
97
CurlTestTearDown()98 absl::Status CurlTestUtils::CurlTestTearDown() {
99 // Cleanup curl
100 return api_->curl_easy_cleanup(curl_.get());
101 }
102
PerformRequest()103 absl::StatusOr<std::string> CurlTestUtils::PerformRequest() {
104 // Perform the request
105 SAPI_ASSIGN_OR_RETURN(int curl_code, api_->curl_easy_perform(curl_.get()));
106 if (curl_code != curl::CURLE_OK) {
107 return absl::UnavailableError(absl::StrCat(
108 "curl_easy_perform returned with the error code ", curl_code));
109 }
110
111 // Get pointer to the memory chunk
112 SAPI_RETURN_IF_ERROR(sandbox_->TransferFromSandboxee(chunk_.get()));
113 return std::string(reinterpret_cast<char*>(chunk_->GetData()));
114 }
115
116 namespace {
117
118 // Read the socket until str is completely read
// Reads `socket` one byte at a time until the data read so far ends with
// `str`, and returns everything read, including `str`. Reading byte-by-byte
// guarantees nothing past the terminator is consumed from the socket.
//
// Returns "" on EOF, on a read error, or once `max_request_size` bytes have
// been read without finding `str`. A read interrupted by a signal (EINTR) is
// retried rather than treated as a failure.
std::string ReadUntil(const int socket, const std::string& str,
                      const size_t max_request_size) {
  std::string str_read;
  str_read.reserve(max_request_size);

  // True once `str` is a suffix of `str_read`.
  const auto ends_with_str = [&str_read, &str] {
    return str_read.size() >= str.size() &&
           str_read.compare(str_read.size() - str.size(), str.size(), str) ==
               0;
  };

  while (!ends_with_str()) {
    if (str_read.size() >= max_request_size) {
      return "";
    }
    char next_char;
    const ssize_t bytes_read = read(socket, &next_char, 1);
    if (bytes_read < 0 && errno == EINTR) {
      continue;
    }
    if (bytes_read < 1) {
      return "";
    }
    str_read += next_char;
  }

  return str_read;
}
136
137 // Parse HTTP headers to return the Content-Length
// Returns the value of the "Content-Length: " header in `headers`.
//
// Returns 0 when the header is absent. Returns -1 when the value is malformed:
// - not a plain non-negative decimal number (optional spaces/tabs around it
//   are allowed),
// - not terminated by CRLF,
// - 64 or more bytes long, or
// - too large for ssize_t.
//
// Parsing uses std::from_chars instead of std::stoi. stoi throws on such
// input, and this runs on the server thread, where an uncaught exception
// would terminate the whole test binary.
ssize_t GetContentLength(const std::string& headers) {
  constexpr char kContentLength[] = "Content-Length: ";
  constexpr size_t kMaxLengthBytes = 64;

  // Find the Content-Length header
  const size_t length_header_start = headers.find(kContentLength);
  if (length_header_start == std::string::npos) {
    // There is no Content-Length field
    return 0;
  }

  // The value runs from the end of the header name up to the next CRLF.
  size_t value_start = length_header_start + sizeof(kContentLength) - 1;
  size_t value_end = headers.find("\r\n", value_start);
  if (value_end == std::string::npos ||
      value_end - value_start >= kMaxLengthBytes) {
    return -1;
  }

  // Trim optional whitespace around the number.
  const auto is_space = [](char c) { return c == ' ' || c == '\t'; };
  while (value_start < value_end && is_space(headers[value_start])) {
    ++value_start;
  }
  while (value_end > value_start && is_space(headers[value_end - 1])) {
    --value_end;
  }

  // The whole remaining span must parse as one non-negative integer.
  ssize_t length = -1;
  const char* const first = headers.data() + value_start;
  const char* const last = headers.data() + value_end;
  const auto [parse_end, error] = std::from_chars(first, last, length);
  if (error != std::errc() || parse_end != last || length < 0) {
    return -1;
  }
  return length;
}
160
161 // Read exactly content_bytes from the socket
// Reads exactly `content_bytes` bytes from `socket` and returns them.
//
// Returns "" if EOF or a read error occurs first; a request for 0 bytes also
// yields "". Reads interrupted by a signal (EINTR) are retried.
std::string ReadExact(int socket, size_t content_bytes) {
  std::string str_read(content_bytes, '\0');
  size_t total_read = 0;

  // Ask for all remaining bytes at once and loop on short reads. Never asking
  // for more than what is left means nothing beyond the content is consumed.
  while (total_read < content_bytes) {
    const ssize_t bytes_read =
        read(socket, &str_read[total_read], content_bytes - total_read);
    if (bytes_read < 0 && errno == EINTR) {
      continue;
    }
    if (bytes_read < 1) {
      return "";
    }
    total_read += static_cast<size_t>(bytes_read);
  }

  return str_read;
}
177
178 // Listen on the socket and answer back to requests
ServerLoop(int listening_socket,sockaddr_in socket_address)179 void ServerLoop(int listening_socket, sockaddr_in socket_address) {
180 socklen_t socket_address_size = sizeof(socket_address);
181
182 // Listen on the socket (maximum 1 connection)
183 if (listen(listening_socket, 1) == -1) {
184 return;
185 }
186
187 // Keep accepting connections until the thread is terminated
188 // (i.e. server_thread_ is assigned to a new thread or destroyed)
189 for (;;) {
190 // File descriptor to the connection socket
191 // This blocks the thread until a connection is established
192 int accepted_socket =
193 accept(listening_socket, reinterpret_cast<sockaddr*>(&socket_address),
194 reinterpret_cast<socklen_t*>(&socket_address_size));
195 if (accepted_socket == -1) {
196 return;
197 }
198
199 constexpr int kMaxRequestSize = 4096;
200
201 // Read until the end of the headers
202 std::string headers =
203 ReadUntil(accepted_socket, "\r\n\r\n", kMaxRequestSize);
204
205 if (headers == "") {
206 close(accepted_socket);
207 return;
208 }
209
210 // Get the length of the request content
211 ssize_t content_length = GetContentLength(headers);
212 if (content_length > kMaxRequestSize - headers.size() ||
213 content_length < 0) {
214 close(accepted_socket);
215 return;
216 }
217
218 // Read the request content
219 std::string content = ReadExact(accepted_socket, content_length);
220
221 // Prepare a response for the request
222 std::string http_response =
223 "HTTP/1.1 200 OK\nContent-Type: text/plain\nContent-Length: ";
224
225 if (headers.substr(0, 3) == "GET") {
226 http_response += "2\r\n\r\nOK";
227
228 } else if (headers.substr(0, 4) == "POST") {
229 http_response +=
230 std::to_string(content.size()) + "\r\n\r\n" + std::string{content};
231
232 } else {
233 close(accepted_socket);
234 return;
235 }
236
237 // Ignore any errors, the connection will be closed anyway
238 write(accepted_socket, http_response.c_str(), http_response.size());
239
240 // Close the socket
241 close(accepted_socket);
242 }
243 }
244
245 } // namespace
246
StartMockServer()247 void CurlTestUtils::StartMockServer() {
248 // Get the socket file descriptor
249 int listening_socket = socket(AF_INET, SOCK_STREAM, 0);
250
251 // Create the socket address object
252 // The port is set to 0, meaning that it will be auto assigned
253 // Only local connections can access this socket
254 sockaddr_in socket_address{AF_INET, 0, htonl(INADDR_LOOPBACK)};
255 socklen_t socket_address_size = sizeof(socket_address);
256 if (listening_socket == -1) {
257 return;
258 }
259
260 // Bind the file descriptor to the socket address object
261 if (bind(listening_socket, reinterpret_cast<sockaddr*>(&socket_address),
262 socket_address_size) == -1) {
263 return;
264 }
265
266 // Assign an available port to the socket address object
267 if (getsockname(listening_socket,
268 reinterpret_cast<sockaddr*>(&socket_address),
269 &socket_address_size) == -1) {
270 return;
271 }
272
273 // Get the port number
274 port_ = ntohs(socket_address.sin_port);
275
276 // Set server_thread_ operation to socket listening
277 server_thread_ = std::thread(ServerLoop, listening_socket, socket_address);
278 }
279
280 } // namespace curl::tests
281