1 /*
2  * Copyright 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "model/setup/async_manager.h"

#include <fcntl.h>        // for fcntl, F_SETFL, O_NONBLOCK
#include <gtest/gtest.h>  // for Message, TestPartResult, SuiteApi...
#include <netdb.h>        // for gethostbyname, h_addr, hostent
#include <netinet/in.h>   // for sockaddr_in, in_addr, INADDR_ANY
#include <stdio.h>        // for printf
#include <sys/socket.h>   // for socket, AF_INET, accept, bind, SOMAXCONN
#include <sys/types.h>    // for in_addr_t
#include <time.h>         // for NULL, size_t
#include <unistd.h>       // for close, write, read

#include <atomic>              // for atomic
#include <cerrno>              // for errno, EAGAIN, EBADF
#include <chrono>              // for steady_clock, milliseconds
#include <condition_variable>  // for condition_variable
#include <cstdint>             // for uint16_t
#include <cstring>             // for memset, strcmp, strcpy, strlen
#include <future>              // for promise, shared_future
#include <mutex>               // for mutex
#include <ratio>               // for ratio
#include <string>              // for string
#include <string_view>         // for string_view
#include <thread>              // for this_thread
#include <tuple>               // for tuple
37 
38 namespace rootcanal {
39 
// A boolean flag guarded by a mutex that threads can block on until it is set.
class Event {
public:
  // Sets the flag to `set` (true by default) and wakes every waiter.
  void set(bool set = true) {
    std::unique_lock<std::mutex> lk(m_);
    set_ = set;
    cv_.notify_all();
  }

  // Clears the flag.
  void reset() { set(false); }

  // Blocks until the flag is set or `timeout` elapses.
  // Returns the value of the flag when the wait ends.
  bool wait_for(std::chrono::microseconds timeout) {
    std::unique_lock<std::mutex> lk(m_);
    return cv_.wait_for(lk, timeout, [&] { return set_; });
  }

  // Returns the current value of the flag. The read takes the same lock as
  // set(), so it cannot race with a writer on another thread.
  bool operator*() const {
    std::lock_guard<std::mutex> lk(m_);
    return set_;
  }

private:
  // mutable so that the const operator*() can lock it.
  mutable std::mutex m_;
  std::condition_variable cv_;
  bool set_{false};
};
62 
63 class AsyncManagerSocketTest : public ::testing::Test {
64 public:
65   static const uint16_t kPort = 6111;
66   static const size_t kBufferSize = 16;
67 
CheckBufferEquals()68   bool CheckBufferEquals() { return strcmp(server_buffer_, client_buffer_) == 0; }
69 
70 protected:
StartServer()71   int StartServer() {
72     struct sockaddr_in serv_addr = {};
73     int fd = socket(AF_INET, SOCK_STREAM, 0);
74     EXPECT_FALSE(fd < 0);
75 
76     serv_addr.sin_family = AF_INET;
77     serv_addr.sin_addr.s_addr = INADDR_ANY;
78     serv_addr.sin_port = htons(kPort);
79     int reuse_flag = 1;
80     EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag, sizeof(reuse_flag)) < 0);
81     EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0);
82 
83     listen(fd, 1);
84     return fd;
85   }
86 
AcceptConnection(int fd)87   int AcceptConnection(int fd) {
88     struct sockaddr_in cli_addr;
89     memset(&cli_addr, 0, sizeof(cli_addr));
90     socklen_t clilen = sizeof(cli_addr);
91 
92     int connection_fd = accept(fd, (struct sockaddr*)&cli_addr, &clilen);
93     EXPECT_FALSE(connection_fd < 0);
94 
95     return connection_fd;
96   }
97 
ConnectSocketPair()98   std::tuple<int, int> ConnectSocketPair() {
99     int cli = ConnectClient();
100     WriteFromClient(cli);
101     AwaitServerResponse(cli);
102     int ser = connection_fd_;
103     connection_fd_ = -1;
104     return {cli, ser};
105   }
106 
ReadIncomingMessage(int fd)107   void ReadIncomingMessage(int fd) {
108     int n;
109     do {
110       n = read(fd, server_buffer_, kBufferSize - 1);
111     } while (n == -1 && errno == EAGAIN);
112 
113     if (n == 0 || errno == EBADF) {
114       // Got EOF, or file descriptor disconnected.
115       async_manager_.StopWatchingFileDescriptor(fd);
116       close(fd);
117     } else {
118       ASSERT_GE(n, 0) << strerror(errno);
119       n = write(fd, "1", 1);
120     }
121   }
122 
SetUp()123   void SetUp() override {
124     memset(server_buffer_, 0, kBufferSize);
125     memset(client_buffer_, 0, kBufferSize);
126     socket_fd_ = -1;
127     connection_fd_ = -1;
128 
129     socket_fd_ = StartServer();
130 
131     async_manager_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
132       connection_fd_ = AcceptConnection(fd);
133 
134       async_manager_.WatchFdForNonBlockingReads(connection_fd_,
135                                                 [this](int fd) { ReadIncomingMessage(fd); });
136     });
137   }
138 
TearDown()139   void TearDown() override {
140     async_manager_.StopWatchingFileDescriptor(socket_fd_);
141     close(socket_fd_);
142     close(connection_fd_);
143     ASSERT_EQ(std::string_view(server_buffer_, kBufferSize),
144               std::string_view(client_buffer_, kBufferSize));
145   }
146 
ConnectClient()147   int ConnectClient() {
148     int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
149     EXPECT_GE(socket_cli_fd, 0) << strerror(errno);
150 
151     struct hostent* server;
152     server = gethostbyname("localhost");
153     EXPECT_FALSE(server == NULL) << strerror(errno);
154 
155     struct sockaddr_in serv_addr;
156     memset((void*)&serv_addr, 0, sizeof(serv_addr));
157     serv_addr.sin_family = AF_INET;
158     serv_addr.sin_addr.s_addr = *(reinterpret_cast<in_addr_t*>(server->h_addr));
159     serv_addr.sin_port = htons(kPort);
160 
161     int result = connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
162     EXPECT_GE(result, 0) << strerror(errno);
163 
164     return socket_cli_fd;
165   }
166 
WriteFromClient(int socket_cli_fd)167   void WriteFromClient(int socket_cli_fd) {
168     strcpy(client_buffer_, "1");
169     int n = write(socket_cli_fd, client_buffer_, strlen(client_buffer_));
170     ASSERT_GT(n, 0) << strerror(errno);
171   }
172 
AwaitServerResponse(int socket_cli_fd)173   void AwaitServerResponse(int socket_cli_fd) {
174     int n = read(socket_cli_fd, client_buffer_, 1);
175     ASSERT_GT(n, 0) << strerror(errno);
176   }
177 
178 protected:
179   AsyncManager async_manager_;
180   int socket_fd_;
181   int connection_fd_;
182   char server_buffer_[kBufferSize];
183   char client_buffer_[kBufferSize];
184 };
185 
TEST_F(AsyncManagerSocketTest, TestOneConnection) {
  // One round trip: the client sends "1", waits for the server's one-byte
  // reply, then closes its end of the connection.
  const int client_fd = ConnectClient();
  WriteFromClient(client_fd);
  AwaitServerResponse(client_fd);
  close(client_fd);
}
195 
TEST_F(AsyncManagerSocketTest, CanUnsubscribeInCallback) {
  using namespace std::chrono_literals;

  int socket_cli_fd = ConnectClient();
  WriteFromClient(socket_cli_fd);
  AwaitServerResponse(socket_cli_fd);
  const int server_fd = connection_fd_;
  fcntl(server_fd, F_SETFL, O_NONBLOCK);

  // 32 bytes of 'x'. The count comes first in std::string's fill constructor;
  // the previous ('x', 32) produced 120 spaces.
  std::string data(32, 'x');

  // Set from the AsyncManager callback and polled here, so it must be atomic.
  std::atomic<bool> stopped{false};
  async_manager_.WatchFdForNonBlockingReads(server_fd, [&](int fd) {
    // The callback unsubscribes itself, which is what this test exercises.
    async_manager_.StopWatchingFileDescriptor(fd);
    char buf[32];
    while (read(fd, buf, sizeof(buf)) > 0)
      ;
    stopped = true;
  });

  // Keep generating read events until the callback has run. The deadline turns
  // a missing callback into a test failure instead of a hang.
  const auto deadline = std::chrono::steady_clock::now() + 5s;
  while (!stopped && std::chrono::steady_clock::now() < deadline) {
    write(socket_cli_fd, data.data(), data.size());
    std::this_thread::sleep_for(5ms);
  }

  // The callback captures locals by reference. If it never ran, it must not
  // be able to fire after this test returns.
  async_manager_.StopWatchingFileDescriptor(server_fd);

  EXPECT_TRUE(stopped) << "The read callback was never invoked";
  close(socket_cli_fd);
}
223 
TEST_F(AsyncManagerSocketTest, CanUnsubscribeTaskFromWithinTask) {
  using namespace std::chrono_literals;
  Event running;

  // The task cancels itself by id, but the id is only known once
  // ExecAsyncPeriodically() returns. That could in principle be after the task
  // first fires, so the id is handed over through a future. The previous code
  // hard-coded the id 1 instead.
  std::promise<AsyncTaskId> task_id_promise;
  std::shared_future<AsyncTaskId> task_id_future = task_id_promise.get_future().share();

  AsyncTaskId task_id = async_manager_.ExecAsyncPeriodically(
          1, 1ms, 2ms, [&running, task_id_future, this]() {
            const AsyncTaskId self_id = task_id_future.get();
            EXPECT_TRUE(async_manager_.CancelAsyncTask(self_id))
                    << "We were scheduled, so cancel should return true";
            EXPECT_FALSE(async_manager_.CancelAsyncTask(self_id))
                    << "We were not scheduled, so cancel should return false";
            running.set(true);
          });
  task_id_promise.set_value(task_id);

  EXPECT_TRUE(running.wait_for(100ms));
}
237 
TEST_F(AsyncManagerSocketTest, UnsubScribeWaitsUntilCompletion) {
  using namespace std::chrono_literals;
  Event running;
  std::atomic<bool> cancel_done{false};
  std::atomic<bool> task_complete{false};
  AsyncTaskId task_id = async_manager_.ExecAsyncPeriodically(
          1, 1ms, 2ms, [&running, &cancel_done, &task_complete]() {
            // Let the other thread know we are in the callback.
            running.set(true);
            // Timing-based: stay busy long enough for CancelAsyncTask() to be
            // called while this callback is still running.
            std::this_thread::sleep_for(20ms);
            EXPECT_FALSE(cancel_done.load())
                    << "Task cancellation did not wait for us to complete!";
            task_complete.store(true);
          });

  EXPECT_TRUE(running.wait_for(100ms));
  // steady_clock is monotonic, so wall-clock adjustments cannot skew the
  // elapsed-time check below.
  const auto start = std::chrono::steady_clock::now();

  // The callback sleeps for 20ms, so cancelling should take some time.
  EXPECT_TRUE(async_manager_.CancelAsyncTask(task_id))
          << "We were scheduled, so cancel should return true";
  cancel_done.store(true);
  EXPECT_TRUE(task_complete.load()) << "We managed to cancel a task while it was not yet finished.";
  const auto end = std::chrono::steady_clock::now();
  const auto passed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
  EXPECT_GT(passed_ms.count(), 10);
}
266 
TEST_F(AsyncManagerSocketTest,NoEventsAfterUnsubscribe)267 TEST_F(AsyncManagerSocketTest, NoEventsAfterUnsubscribe) {
268   // This tests makes sure the AsyncManager never fires an event
269   // after calling StopWatchingFileDescriptor.
270   using clock = std::chrono::system_clock;
271   using namespace std::chrono_literals;
272 
273   clock::time_point time_fast_called;
274   clock::time_point time_slow_called;
275   clock::time_point time_stopped_listening;
276 
277   int round = 0;
278   auto [slow_cli_fd, slow_s_fd] = ConnectSocketPair();
279   fcntl(slow_s_fd, F_SETFL, O_NONBLOCK);
280 
281   auto [fast_cli_fd, fast_s_fd] = ConnectSocketPair();
282   fcntl(fast_s_fd, F_SETFL, O_NONBLOCK);
283 
284   std::string data(1, 'x');
285 
286   // The idea here is as follows:
287   // We want to make sure that an unsubscribed callback never gets called.
288   // This is to make sure we can safely do things like this:
289   //
290   // class Foo {
291   //   Foo(int fd, AsyncManager* am) : fd_(fd), am_(am) {
292   //     am_->WatchFdForNonBlockingReads(
293   //         fd, [&](int fd) { printf("This shouldn't crash! %p\n", this); });
294   //   }
295   //   ~Foo() { am_->StopWatchingFileDescriptor(fd_); }
296   //
297   //   AsyncManager* am_;
298   //   int fd_;
299   // };
300   //
301   // We are going to force a failure as follows:
302   //
303   // The slow callback needs to be called first, if it does not we cannot
304   // force failure, so we have to try multiple times.
305   //
306   // t1, is the thread doing the loop.
307   // t2, is the async manager handler thread.
308   //
309   // t1 will block until the slowcallback.
310   // t2 will now block (for at most 250 ms).
311   // t1 will unsubscribe the fast callback.
312   // 2 cases:
313   //   with bug:
314   //      - t1 takes a timestamp, unblocks t2,
315   //      - t2 invokes the fast callback, and gets a timestamp.
316   //      - Now the unsubscribe time is before the callback time.
317   //   without bug.:
318   //      - t1 locks un unsusbcribe in asyn manager
319   //      - t2 unlocks due to timeout,
320   //      - t2 invokes the fast callback, and gets a timestamp.
321   //      - t1 is unlocked and gets a timestamp.
322   //      - Now the unsubscribe time is after the callback time..
323 
324   do {
325     Event unblock_slow, inslow, infast;
326     time_fast_called = {};
327     time_slow_called = {};
328     time_stopped_listening = {};
329     printf("round: %d\n", round++);
330 
331     // Register fd events
332     async_manager_.WatchFdForNonBlockingReads(slow_s_fd, [&](int /*fd*/) {
333       if (*inslow) {
334         return;
335       }
336       time_slow_called = clock::now();
337       printf("slow: %lld\n", time_slow_called.time_since_epoch().count() % 10000);
338       inslow.set();
339       unblock_slow.wait_for(25ms);
340     });
341 
342     async_manager_.WatchFdForNonBlockingReads(fast_s_fd, [&](int /*fd*/) {
343       if (*infast) {
344         return;
345       }
346       time_fast_called = clock::now();
347       printf("fast: %lld\n", time_fast_called.time_since_epoch().count() % 10000);
348       infast.set();
349     });
350 
351     // Generate fd events
352     write(fast_cli_fd, data.data(), data.size());
353     write(slow_cli_fd, data.data(), data.size());
354 
355     // Block in the right places.
356     if (inslow.wait_for(25ms)) {
357       async_manager_.StopWatchingFileDescriptor(fast_s_fd);
358       time_stopped_listening = clock::now();
359       printf("stop: %lld\n", time_stopped_listening.time_since_epoch().count() % 10000);
360       unblock_slow.set();
361     }
362 
363     infast.wait_for(25ms);
364 
365     // Unregister.
366     async_manager_.StopWatchingFileDescriptor(fast_s_fd);
367     async_manager_.StopWatchingFileDescriptor(slow_s_fd);
368   } while (time_fast_called < time_slow_called);
369 
370   // fast before stop listening.
371   ASSERT_LT(time_fast_called.time_since_epoch().count(),
372             time_stopped_listening.time_since_epoch().count());
373 
374   // Cleanup
375   close(fast_cli_fd);
376   close(fast_s_fd);
377   close(slow_cli_fd);
378   close(slow_s_fd);
379 }
380 
TEST_F(AsyncManagerSocketTest, TestRepeatedConnections) {
  // Open, exercise and close one connection at a time, many times in a row.
  constexpr int kNumConnections = 30;
  for (int attempt = 0; attempt != kNumConnections; ++attempt) {
    const int client_fd = ConnectClient();
    WriteFromClient(client_fd);
    AwaitServerResponse(client_fd);
    close(client_fd);
  }
}
390 
TEST_F(AsyncManagerSocketTest, TestMultipleConnections) {
  // Open all connections and send on each before reading any reply, so that
  // many connections are active at once.
  static constexpr int num_connections = 30;
  int socket_cli_fd[num_connections];
  for (int i = 0; i < num_connections; i++) {
    socket_cli_fd[i] = ConnectClient();
    // Any non-negative value (including 0) is a valid descriptor.
    ASSERT_GE(socket_cli_fd[i], 0);
    WriteFromClient(socket_cli_fd[i]);
  }
  for (int i = 0; i < num_connections; i++) {
    AwaitServerResponse(socket_cli_fd[i]);
    close(socket_cli_fd[i]);
  }
}
404 
405 class AsyncManagerTest : public ::testing::Test {
406 public:
407   AsyncManager async_manager_;
408 };
409 
TEST_F(AsyncManagerTest, TestSetupTeardown) {
  // Intentionally empty: only the fixture's AsyncManager construction and
  // destruction are exercised.
}
411 
TEST_F(AsyncManagerTest, TestCancelTask) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  // Atomic: if cancellation loses the race, the task sets this flag from an
  // AsyncManager thread while this thread reads it.
  std::atomic<bool> task1_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                                                  [&task1_ran]() { task1_ran = true; });
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(task1_ran.load());
}
421 
TEST_F(AsyncManagerTest, TestCancelLongTask) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  // The flags are set from an AsyncManager thread and polled here. A plain bool
  // in an empty spin loop is a data race that the optimizer may turn into an
  // infinite loop, so use atomics.
  std::atomic<bool> task1_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                                                  [&task1_ran]() { task1_ran = true; });
  std::atomic<bool> task2_ran{false};
  AsyncTaskId task2_id = async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                                                  [&task2_ran]() { task2_ran = true; });
  ASSERT_FALSE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  // Wait for the short task to run.
  while (!task1_ran.load()) {
    std::this_thread::yield();
  }
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(async_manager_.CancelAsyncTask(task2_id));
}
440 
TEST_F(AsyncManagerTest, TestCancelAsyncTasksFromUser) {
  AsyncUserId user1 = async_manager_.GetNextUserId();
  AsyncUserId user2 = async_manager_.GetNextUserId();
  // The flags are set from an AsyncManager thread and polled below. Atomics keep
  // the spin loop from being compiled into a loop over a cached value.
  std::atomic<bool> task1_ran{false};
  std::atomic<bool> task2_ran{false};
  std::atomic<bool> task3_ran{false};
  std::atomic<bool> task4_ran{false};
  std::atomic<bool> task5_ran{false};
  AsyncTaskId task1_id = async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                                                  [&task1_ran]() { task1_ran = true; });
  AsyncTaskId task2_id = async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                                                  [&task2_ran]() { task2_ran = true; });
  AsyncTaskId task3_id = async_manager_.ExecAsync(user1, std::chrono::milliseconds(2),
                                                  [&task3_ran]() { task3_ran = true; });
  AsyncTaskId task4_id = async_manager_.ExecAsync(user1, std::chrono::seconds(2),
                                                  [&task4_ran]() { task4_ran = true; });
  AsyncTaskId task5_id = async_manager_.ExecAsync(user2, std::chrono::milliseconds(2),
                                                  [&task5_ran]() { task5_ran = true; });
  ASSERT_FALSE(task1_ran.load());
  // Wait for the three short tasks (from both users) to run.
  while (!task1_ran.load() || !task3_ran.load() || !task5_ran.load()) {
    std::this_thread::yield();
  }
  ASSERT_TRUE(task1_ran.load());
  ASSERT_FALSE(task2_ran.load());
  ASSERT_TRUE(task3_ran.load());
  ASSERT_FALSE(task4_ran.load());
  ASSERT_TRUE(task5_ran.load());
  // Cancelling user1's tasks removes the pending long tasks (2 and 4); the
  // others have already run, so nothing is left for CancelAsyncTask to cancel.
  async_manager_.CancelAsyncTasksFromUser(user1);
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task1_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task2_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task3_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task4_id));
  ASSERT_FALSE(async_manager_.CancelAsyncTask(task5_id));
}
479 
480 }  // namespace rootcanal
481