1 /*
2 * Copyright 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "model/setup/async_manager.h"

#include <fcntl.h>       // for fcntl, F_SETFL, O_NONBLOCK
#include <gtest/gtest.h> // for Message, TestPartResult, SuiteApi...
#include <netdb.h>       // for gethostbyname, h_addr, hostent
#include <netinet/in.h>  // for sockaddr_in, in_addr, INADDR_ANY
#include <stdio.h>       // for printf
#include <sys/socket.h>  // for socket, AF_INET, accept, bind
#include <sys/types.h>   // for in_addr_t
#include <time.h>        // for NULL, size_t
#include <unistd.h>      // for close, write, read

#include <atomic>             // for atomic
#include <cerrno>             // for errno, EAGAIN, EBADF, EINTR
#include <chrono>             // for steady_clock, milliseconds
#include <condition_variable> // for condition_variable
#include <cstdint>            // for uint16_t
#include <cstring>            // for memset, strcmp, strcpy, strlen
#include <mutex>              // for mutex
#include <ratio>              // for ratio
#include <string>             // for string
#include <string_view>        // for string_view
#include <thread>
#include <tuple>              // for tuple
37
38 namespace rootcanal {
39
// Small boolean flag guarded by a mutex. One thread can set()/reset() it
// while another blocks in wait_for() until it becomes true or a timeout hits.
class Event {
 public:
  // Stores `value` and wakes every thread blocked in wait_for().
  void set(bool value = true) {
    std::unique_lock<std::mutex> lk(m_);
    set_ = value;
    cv_.notify_all();
  }

  // Clears the flag; same as set(false).
  void reset() { set(false); }

  // Blocks until the flag is true or `timeout` elapses. Returns the flag's
  // value when waking, i.e. false if the wait timed out.
  bool wait_for(std::chrono::microseconds timeout) {
    std::unique_lock<std::mutex> lk(m_);
    return cv_.wait_for(lk, timeout, [&] { return set_; });
  }

  // Current value of the flag. The mutex is taken so this read cannot race
  // with set()/reset() being called from another thread.
  bool operator*() const {
    std::lock_guard<std::mutex> lk(m_);
    return set_;
  }

 private:
  mutable std::mutex m_;  // mutable so the const accessor can lock it
  std::condition_variable cv_;
  bool set_{false};
};
62
63 class AsyncManagerSocketTest : public ::testing::Test {
64 public:
65 static const uint16_t kPort = 6111;
66 static const size_t kBufferSize = 16;
67
CheckBufferEquals()68 bool CheckBufferEquals() { return strcmp(server_buffer_, client_buffer_) == 0; }
69
70 protected:
StartServer()71 int StartServer() {
72 struct sockaddr_in serv_addr = {};
73 int fd = socket(AF_INET, SOCK_STREAM, 0);
74 EXPECT_FALSE(fd < 0);
75
76 serv_addr.sin_family = AF_INET;
77 serv_addr.sin_addr.s_addr = INADDR_ANY;
78 serv_addr.sin_port = htons(kPort);
79 int reuse_flag = 1;
80 EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag, sizeof(reuse_flag)) < 0);
81 EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0);
82
83 listen(fd, 1);
84 return fd;
85 }
86
AcceptConnection(int fd)87 int AcceptConnection(int fd) {
88 struct sockaddr_in cli_addr;
89 memset(&cli_addr, 0, sizeof(cli_addr));
90 socklen_t clilen = sizeof(cli_addr);
91
92 int connection_fd = accept(fd, (struct sockaddr*)&cli_addr, &clilen);
93 EXPECT_FALSE(connection_fd < 0);
94
95 return connection_fd;
96 }
97
ConnectSocketPair()98 std::tuple<int, int> ConnectSocketPair() {
99 int cli = ConnectClient();
100 WriteFromClient(cli);
101 AwaitServerResponse(cli);
102 int ser = connection_fd_;
103 connection_fd_ = -1;
104 return {cli, ser};
105 }
106
ReadIncomingMessage(int fd)107 void ReadIncomingMessage(int fd) {
108 int n;
109 do {
110 n = read(fd, server_buffer_, kBufferSize - 1);
111 } while (n == -1 && errno == EAGAIN);
112
113 if (n == 0 || errno == EBADF) {
114 // Got EOF, or file descriptor disconnected.
115 async_manager_.StopWatchingFileDescriptor(fd);
116 close(fd);
117 } else {
118 ASSERT_GE(n, 0) << strerror(errno);
119 n = write(fd, "1", 1);
120 }
121 }
122
SetUp()123 void SetUp() override {
124 memset(server_buffer_, 0, kBufferSize);
125 memset(client_buffer_, 0, kBufferSize);
126 socket_fd_ = -1;
127 connection_fd_ = -1;
128
129 socket_fd_ = StartServer();
130
131 async_manager_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
132 connection_fd_ = AcceptConnection(fd);
133
134 async_manager_.WatchFdForNonBlockingReads(connection_fd_,
135 [this](int fd) { ReadIncomingMessage(fd); });
136 });
137 }
138
TearDown()139 void TearDown() override {
140 async_manager_.StopWatchingFileDescriptor(socket_fd_);
141 close(socket_fd_);
142 close(connection_fd_);
143 ASSERT_EQ(std::string_view(server_buffer_, kBufferSize),
144 std::string_view(client_buffer_, kBufferSize));
145 }
146
ConnectClient()147 int ConnectClient() {
148 int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
149 EXPECT_GE(socket_cli_fd, 0) << strerror(errno);
150
151 struct hostent* server;
152 server = gethostbyname("localhost");
153 EXPECT_FALSE(server == NULL) << strerror(errno);
154
155 struct sockaddr_in serv_addr;
156 memset((void*)&serv_addr, 0, sizeof(serv_addr));
157 serv_addr.sin_family = AF_INET;
158 serv_addr.sin_addr.s_addr = *(reinterpret_cast<in_addr_t*>(server->h_addr));
159 serv_addr.sin_port = htons(kPort);
160
161 int result = connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
162 EXPECT_GE(result, 0) << strerror(errno);
163
164 return socket_cli_fd;
165 }
166
WriteFromClient(int socket_cli_fd)167 void WriteFromClient(int socket_cli_fd) {
168 strcpy(client_buffer_, "1");
169 int n = write(socket_cli_fd, client_buffer_, strlen(client_buffer_));
170 ASSERT_GT(n, 0) << strerror(errno);
171 }
172
AwaitServerResponse(int socket_cli_fd)173 void AwaitServerResponse(int socket_cli_fd) {
174 int n = read(socket_cli_fd, client_buffer_, 1);
175 ASSERT_GT(n, 0) << strerror(errno);
176 }
177
178 protected:
179 AsyncManager async_manager_;
180 int socket_fd_;
181 int connection_fd_;
182 char server_buffer_[kBufferSize];
183 char client_buffer_[kBufferSize];
184 };
185
// A single client connects, sends one message, waits for the server's reply
// and disconnects.
TEST_F(AsyncManagerSocketTest, TestOneConnection) {
  const int client_fd = ConnectClient();
  WriteFromClient(client_fd);
  AwaitServerResponse(client_fd);
  close(client_fd);
}
195
// A read callback unsubscribes its own descriptor from inside the callback.
// The test keeps producing read events until that callback has run.
TEST_F(AsyncManagerSocketTest, CanUnsubscribeInCallback) {
  using namespace std::chrono_literals;

  int socket_cli_fd = ConnectClient();
  WriteFromClient(socket_cli_fd);
  AwaitServerResponse(socket_cli_fd);
  // Non-blocking, so the drain loop in the callback ends once the socket is empty.
  fcntl(connection_fd_, F_SETFL, O_NONBLOCK);

  // 32 'x' bytes. The constructor takes (count, ch), in that order.
  std::string data(32, 'x');

  // Set on the AsyncManager thread and polled on this one, so it must be atomic.
  std::atomic<bool> stopped{false};
  async_manager_.WatchFdForNonBlockingReads(connection_fd_, [&](int fd) {
    async_manager_.StopWatchingFileDescriptor(fd);
    char buf[32];
    while (read(fd, buf, sizeof(buf)) > 0)
      ;
    stopped = true;
  });

  // Keep generating read events until the callback has run. Give up after a
  // generous deadline rather than hanging the test binary.
  const auto deadline = std::chrono::steady_clock::now() + 5s;
  while (!stopped.load() && std::chrono::steady_clock::now() < deadline) {
    write(socket_cli_fd, data.data(), data.size());
    std::this_thread::sleep_for(5ms);
  }

  if (!stopped.load()) {
    // The callback captures locals by reference; make sure it cannot fire
    // after this test body returns.
    async_manager_.StopWatchingFileDescriptor(connection_fd_);
  }
  EXPECT_TRUE(stopped.load()) << "Read callback was never invoked";
  close(socket_cli_fd);
}
223
// A periodic task cancels itself from inside its own callback. The first
// cancel must succeed; an immediate second cancel must report that nothing
// is scheduled any more.
TEST_F(AsyncManagerSocketTest, CanUnsubscribeTaskFromWithinTask) {
  using namespace std::chrono_literals;
  Event running;

  // NOTE(review): task id 1 is hard-coded, presumably because this is the
  // first task created on the fixture's fresh AsyncManager — confirm.
  auto cancel_self = [&running, this]() {
    EXPECT_TRUE(async_manager_.CancelAsyncTask(1))
        << "We were scheduled, so cancel should return true";
    EXPECT_FALSE(async_manager_.CancelAsyncTask(1))
        << "We were not scheduled, so cancel should return false";
    running.set(true);
  };
  async_manager_.ExecAsyncPeriodically(1, 1ms, 2ms, cancel_self);

  EXPECT_TRUE(running.wait_for(100ms));
}
237
// CancelAsyncTask() on a task whose callback is currently running must not
// return until that callback has finished.
TEST_F(AsyncManagerSocketTest, UnsubScribeWaitsUntilCompletion) {
  using namespace std::chrono_literals;
  Event running;
  std::atomic<bool> cancel_done{false};
  std::atomic<bool> task_complete{false};
  AsyncTaskId task_id = async_manager_.ExecAsyncPeriodically(
      1, 1ms, 2ms, [&running, &cancel_done, &task_complete]() {
        // Let the other thread know we are in the callback.
        running.set(true);
        // Wee bit of a hack that relies on timing..
        std::this_thread::sleep_for(20ms);
        EXPECT_FALSE(cancel_done.load())
            << "Task cancellation did not wait for us to complete!";
        task_complete.store(true);
      });

  EXPECT_TRUE(running.wait_for(100ms));
  // steady_clock: an elapsed-time measurement must not be affected by
  // wall-clock adjustments, which system_clock is subject to.
  auto start = std::chrono::steady_clock::now();

  // There is a 20ms wait.. so we know that this should take some time.
  EXPECT_TRUE(async_manager_.CancelAsyncTask(task_id))
      << "We were scheduled, so cancel should return true";
  cancel_done.store(true);
  EXPECT_TRUE(task_complete.load()) << "We managed to cancel a task while it was not yet finished.";
  auto end = std::chrono::steady_clock::now();
  auto passed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
  EXPECT_GT(passed_ms.count(), 10);
}
266
TEST_F(AsyncManagerSocketTest, NoEventsAfterUnsubscribe) {
  // This test makes sure the AsyncManager never fires an event
  // after calling StopWatchingFileDescriptor.
  // steady_clock: timestamps taken on two different threads are compared with
  // each other, so they must come from a monotonic clock that cannot jump.
  using clock = std::chrono::steady_clock;
  using namespace std::chrono_literals;

  clock::time_point time_fast_called;
  clock::time_point time_slow_called;
  clock::time_point time_stopped_listening;

  // Low digits of a timestamp, for the debug output. The cast is required:
  // the integer type of count() is implementation-defined, while %lld needs
  // exactly a long long.
  auto low_digits = [](clock::time_point t) {
    return static_cast<long long>(t.time_since_epoch().count() % 10000);
  };

  int round = 0;
  auto [slow_cli_fd, slow_s_fd] = ConnectSocketPair();
  fcntl(slow_s_fd, F_SETFL, O_NONBLOCK);

  auto [fast_cli_fd, fast_s_fd] = ConnectSocketPair();
  fcntl(fast_s_fd, F_SETFL, O_NONBLOCK);

  std::string data(1, 'x');

  // The idea here is as follows:
  // We want to make sure that an unsubscribed callback never gets called.
  // This is to make sure we can safely do things like this:
  //
  // class Foo {
  //   Foo(int fd, AsyncManager* am) : fd_(fd), am_(am) {
  //     am_->WatchFdForNonBlockingReads(
  //         fd, [&](int fd) { printf("This shouldn't crash! %p\n", this); });
  //   }
  //   ~Foo() { am_->StopWatchingFileDescriptor(fd_); }
  //
  //   AsyncManager* am_;
  //   int fd_;
  // };
  //
  // We are going to force a failure as follows:
  //
  // The slow callback needs to be called first; if it is not, we cannot
  // force a failure, so we may have to try multiple times.
  //
  // t1 is the thread running the loop.
  // t2 is the async manager handler thread.
  //
  // t1 blocks until the slow callback has started.
  // t2 now blocks inside the slow callback (for at most 25 ms).
  // t1 unsubscribes the fast callback.
  // 2 cases:
  //   with bug:
  //   - t1 takes a timestamp and unblocks t2,
  //   - t2 invokes the fast callback and takes a timestamp.
  //   - Now the unsubscribe time is before the callback time.
  //   without bug:
  //   - t1 blocks in unsubscribe inside the async manager,
  //   - t2 unblocks due to the timeout,
  //   - t2 invokes the fast callback and takes a timestamp.
  //   - t1 is unblocked and takes a timestamp.
  //   - Now the unsubscribe time is after the callback time.

  do {
    Event unblock_slow, inslow, infast;
    time_fast_called = {};
    time_slow_called = {};
    time_stopped_listening = {};
    printf("round: %d\n", round++);

    // Register fd events
    async_manager_.WatchFdForNonBlockingReads(slow_s_fd, [&](int /*fd*/) {
      if (*inslow) {
        return;
      }
      time_slow_called = clock::now();
      printf("slow: %lld\n", low_digits(time_slow_called));
      inslow.set();
      unblock_slow.wait_for(25ms);
    });

    async_manager_.WatchFdForNonBlockingReads(fast_s_fd, [&](int /*fd*/) {
      if (*infast) {
        return;
      }
      time_fast_called = clock::now();
      printf("fast: %lld\n", low_digits(time_fast_called));
      infast.set();
    });

    // Generate fd events
    write(fast_cli_fd, data.data(), data.size());
    write(slow_cli_fd, data.data(), data.size());

    // Block in the right places.
    if (inslow.wait_for(25ms)) {
      async_manager_.StopWatchingFileDescriptor(fast_s_fd);
      time_stopped_listening = clock::now();
      printf("stop: %lld\n", low_digits(time_stopped_listening));
      unblock_slow.set();
    }

    infast.wait_for(25ms);

    // Unregister.
    async_manager_.StopWatchingFileDescriptor(fast_s_fd);
    async_manager_.StopWatchingFileDescriptor(slow_s_fd);
  } while (time_fast_called < time_slow_called);

  // Cleanup first, so a failing assertion below does not leak the descriptors.
  close(fast_cli_fd);
  close(fast_s_fd);
  close(slow_cli_fd);
  close(slow_s_fd);

  // fast before stop listening.
  ASSERT_LT(time_fast_called.time_since_epoch().count(),
            time_stopped_listening.time_since_epoch().count());
}
380
// Connects, round-trips one message and disconnects 30 times in a row, each
// client closed before the next one is opened.
TEST_F(AsyncManagerSocketTest, TestRepeatedConnections) {
  constexpr int kNumConnections = 30;
  for (int attempt = 0; attempt != kNumConnections; ++attempt) {
    const int client_fd = ConnectClient();
    WriteFromClient(client_fd);
    AwaitServerResponse(client_fd);
    close(client_fd);
  }
}
390
// Opens 30 clients and sends one message on each before collecting any
// reply, so all connections are open at the same time.
TEST_F(AsyncManagerSocketTest, TestMultipleConnections) {
  constexpr int kNumConnections = 30;
  int client_fds[kNumConnections];

  // Phase 1: connect and send on every client.
  for (int& fd : client_fds) {
    fd = ConnectClient();
    ASSERT_GT(fd, 0);
    WriteFromClient(fd);
  }

  // Phase 2: wait for each reply, then close.
  for (const int fd : client_fds) {
    AwaitServerResponse(fd);
    close(fd);
  }
}
404
405 class AsyncManagerTest : public ::testing::Test {
406 public:
407 AsyncManager async_manager_;
408 };
409
// Exercises only construction and destruction of the fixture's AsyncManager.
TEST_F(AsyncManagerTest, TestSetupTeardown) {
  // Intentionally empty.
}
411
// A one-shot task cancelled before its 2 ms delay elapses must never run.
TEST_F(AsyncManagerTest, TestCancelTask) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  // Atomic: if the task did run, the write would happen on the AsyncManager
  // thread while this thread reads the flag.
  std::atomic<bool> task1_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                                                  [&task1_ran]() { task1_ran = true; });
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(task1_ran.load());
}
421
// Cancelling a task that already ran returns false. A pending long-delay task
// can still be cancelled (returns true) and has not run yet.
TEST_F(AsyncManagerTest, TestCancelLongTask) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  // Atomics: written on the AsyncManager thread, polled on this one. A plain
  // bool here is a data race and lets the compiler hoist the load out of the
  // wait loop below.
  std::atomic<bool> task1_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                                                  [&task1_ran]() { task1_ran = true; });
  std::atomic<bool> task2_ran{false};
  AsyncTaskId task2_id = async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                                                  [&task2_ran]() { task2_ran = true; });
  ASSERT_FALSE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  // Wait for the short task to run.
  while (!task1_ran.load()) {
    std::this_thread::yield();
  }
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task2_id));
}
440
// Schedules short (2 ms) and long (2 s) tasks for two users. After the short
// ones have run, CancelAsyncTasksFromUser(user1) leaves nothing cancellable
// for any of the task ids.
TEST_F(AsyncManagerTest, TestCancelAsyncTasksFromUser) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  AsyncUserId user2 = async_manager_.GetNextUserId();
  // Atomics: written on the AsyncManager thread, polled on this one. A plain
  // bool here is a data race and lets the compiler hoist the loads out of the
  // wait loop below.
  std::atomic<bool> task1_ran{false};
  std::atomic<bool> task2_ran{false};
  std::atomic<bool> task3_ran{false};
  std::atomic<bool> task4_ran{false};
  std::atomic<bool> task5_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                                                  [&task1_ran]() { task1_ran = true; });
  AsyncTaskId task2_id = async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                                                  [&task2_ran]() { task2_ran = true; });
  AsyncTaskId task3_id = async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                                                  [&task3_ran]() { task3_ran = true; });
  AsyncTaskId task4_id = async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                                                  [&task4_ran]() { task4_ran = true; });
  AsyncTaskId task5_id = async_manager_.ExecAsync(user2, std::chrono::milliseconds(2),
                                                  [&task5_ran]() { task5_ran = true; });
  ASSERT_FALSE(task1_ran.load());
  // Wait for the three short tasks to run.
  while (!task1_ran.load() || !task3_ran.load() || !task5_ran.load()) {
    std::this_thread::yield();
  }
  ASSERT_TRUE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(task3_ran.load());
  ASSERT_FALSE(task4_ran.load());
  ASSERT_TRUE(task5_ran.load());
  async_manager_.CancelAsyncTasksFromUser(user1);
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task2_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task3_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task4_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task5_id));
}
479
480 } // namespace rootcanal
481