1 /******************************************************************************
2  *
3  *  Copyright 2020 Google, Inc.
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at:
8  *
9  *  http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  ******************************************************************************/
18 
#include "os/internal/wakelock_native.h"

#include <aidl/android/system/suspend/BnSuspendCallback.h>
#include <aidl/android/system/suspend/BnWakelockCallback.h>
#include <aidl/android/system/suspend/ISuspendControlService.h>
#include <android/binder_auto_utils.h>
#include <android/binder_interface_utils.h>
#include <android/binder_manager.h>
#include <android/binder_process.h>
#include <gtest/gtest.h>

#include <chrono>
#include <future>
#include <memory>
#include <mutex>
#include <utility>
34 
35 namespace testing {
36 
37 using aidl::android::system::suspend::BnSuspendCallback;
38 using aidl::android::system::suspend::BnWakelockCallback;
39 using aidl::android::system::suspend::ISuspendControlService;
40 using bluetooth::os::internal::WakelockNative;
41 using ndk::ScopedAStatus;
42 using ndk::SharedRefBase;
43 using ndk::SpAIBinder;
44 
// Wakelock name that every test acquires/releases and that WakelockCallback is
// registered against.
static const std::string kTestWakelockName("BtWakelockNativeTestLock");

// Serializes access to the promise slots below (and the callback counter) between
// the test thread and callback code.
static std::recursive_mutex mutex;

// One-shot slots armed by PromiseFutureContext; empty (null) while disarmed.
// acquire_promise is fulfilled from notifyAcquired, release_promise from notifyReleased.
static std::unique_ptr<std::promise<void>> acquire_promise;
static std::unique_ptr<std::promise<void>> release_promise;
50 
51 class PromiseFutureContext {
52 public:
FulfilPromise(std::unique_ptr<std::promise<void>> & promise)53   static void FulfilPromise(std::unique_ptr<std::promise<void>>& promise) {
54     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
55     if (promise != nullptr) {
56       std::promise<void>* prom = promise.release();
57       prom->set_value();
58       delete prom;
59     }
60   }
61 
PromiseFutureContext(std::unique_ptr<std::promise<void>> & promise,bool expect_fulfillment)62   explicit PromiseFutureContext(std::unique_ptr<std::promise<void>>& promise,
63                                 bool expect_fulfillment)
64       : promise_(promise), expect_fulfillment_(expect_fulfillment) {
65     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
66     EXPECT_EQ(promise_, nullptr);
67     promise_ = std::make_unique<std::promise<void>>();
68     future_ = promise->get_future();
69   }
70 
~PromiseFutureContext()71   ~PromiseFutureContext() {
72     auto future_status = future_.wait_for(std::chrono::seconds(2));
73     if (expect_fulfillment_) {
74       EXPECT_EQ(future_status, std::future_status::ready);
75     } else {
76       EXPECT_NE(future_status, std::future_status::ready);
77     }
78     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
79     promise_ = nullptr;
80   }
81 
82 private:
83   std::unique_ptr<std::promise<void>>& promise_;
84   bool expect_fulfillment_ = true;
85   std::future<void> future_;
86 };
87 
88 class WakelockCallback : public BnWakelockCallback {
89 public:
notifyAcquired()90   ScopedAStatus notifyAcquired() override {
91     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
92     net_acquired_count++;
93     fprintf(stderr, "notifyAcquired, count = %d\n", net_acquired_count);
94     PromiseFutureContext::FulfilPromise(acquire_promise);
95     return ScopedAStatus::ok();
96   }
notifyReleased()97   ScopedAStatus notifyReleased() override {
98     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
99     net_acquired_count--;
100     fprintf(stderr, "notifyReleased, count = %d\n", net_acquired_count);
101     PromiseFutureContext::FulfilPromise(release_promise);
102     return ScopedAStatus::ok();
103   }
104 
105   int net_acquired_count = 0;
106 };
107 
108 class SuspendCallback : public BnSuspendCallback {
109 public:
notifyWakeup(bool,const std::vector<std::string> &)110   ScopedAStatus notifyWakeup(bool /* success */,
111                              const std::vector<std::string>& /* wakeup_reasons */) override {
112     std::lock_guard<std::recursive_mutex> lock_guard(mutex);
113     fprintf(stderr, "notifyWakeup\n");
114     return ScopedAStatus::ok();
115   }
116 };
117 
118 // There is no way to unregister these callbacks besides when this process dies
119 // Hence, we want to have only one copy of these callbacks per process
120 static std::shared_ptr<SuspendCallback> suspend_callback = nullptr;
121 static std::shared_ptr<WakelockCallback> control_callback = nullptr;
122 
123 class WakelockNativeTest : public Test {
124 protected:
SetUp()125   void SetUp() override {
126     ABinderProcess_setThreadPoolMaxThreadCount(1);
127     ABinderProcess_startThreadPool();
128 
129     WakelockNative::Get().Initialize();
130 
131     auto binder_raw = AServiceManager_waitForService("suspend_control");
132     ASSERT_NE(binder_raw, nullptr);
133     binder.set(binder_raw);
134     control_service_ = ISuspendControlService::fromBinder(binder);
135     if (control_service_ == nullptr) {
136       FAIL() << "Fail to obtain suspend_control";
137     }
138 
139     if (suspend_callback == nullptr) {
140       suspend_callback = SharedRefBase::make<SuspendCallback>();
141       bool is_registered = false;
142       ScopedAStatus status = control_service_->registerCallback(suspend_callback, &is_registered);
143       if (!is_registered || !status.isOk()) {
144         FAIL() << "Fail to register suspend callback";
145       }
146     }
147 
148     if (control_callback == nullptr) {
149       control_callback = SharedRefBase::make<WakelockCallback>();
150       bool is_registered = false;
151       ScopedAStatus status = control_service_->registerWakelockCallback(
152               control_callback, kTestWakelockName, &is_registered);
153       if (!is_registered || !status.isOk()) {
154         FAIL() << "Fail to register wakeup callback";
155       }
156     }
157     control_callback->net_acquired_count = 0;
158   }
159 
TearDown()160   void TearDown() override {
161     control_service_ = nullptr;
162     binder.set(nullptr);
163     WakelockNative::Get().CleanUp();
164   }
165 
166   SpAIBinder binder;
167   std::shared_ptr<ISuspendControlService> control_service_ = nullptr;
168 };
169 
// One acquire followed by one release: each must succeed, trigger exactly its
// own callback, and the net count must end back at zero.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext acquired(acquire_promise, true);
    ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName),
              WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    PromiseFutureContext released(release_promise, true);
    ASSERT_EQ(WakelockNative::Get().Release(kTestWakelockName),
              WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
187 
// Acquiring an already-held wakelock must still succeed but must not produce a
// second acquire callback; a single release then brings the count back to zero.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_acquire) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext first_acquire(acquire_promise, true);
    ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName),
              WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    // No callback expected this time.
    PromiseFutureContext second_acquire(acquire_promise, false);
    ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName),
              WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    PromiseFutureContext release(release_promise, true);
    ASSERT_EQ(WakelockNative::Get().Release(kTestWakelockName),
              WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
212 
// Releasing an already-released wakelock must still succeed but must not produce
// another release callback; the net count stays at zero.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_release) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext acquire(acquire_promise, true);
    ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName),
              WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    PromiseFutureContext first_release(release_promise, true);
    ASSERT_EQ(WakelockNative::Get().Release(kTestWakelockName),
              WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    // No callback expected this time.
    PromiseFutureContext second_release(release_promise, false);
    ASSERT_EQ(WakelockNative::Get().Release(kTestWakelockName),
              WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
237 
// Repeats the acquire/release cycle several times; every cycle must produce one
// acquire and one release callback and return the count to zero.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_in_a_loop) {
  constexpr int kCycles = 10;
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  for (int cycle = 0; cycle < kCycles; ++cycle) {
    {
      PromiseFutureContext acquired(acquire_promise, true);
      ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName),
                WakelockNative::StatusCode::SUCCESS);
    }
    ASSERT_EQ(control_callback->net_acquired_count, 1);

    {
      PromiseFutureContext released(release_promise, true);
      ASSERT_EQ(WakelockNative::Get().Release(kTestWakelockName),
                WakelockNative::StatusCode::SUCCESS);
    }
    ASSERT_EQ(control_callback->net_acquired_count, 0);
  }
}
257 
// CleanUp() while the wakelock is held must release it: a release callback is
// delivered and the net count drops back to zero.
TEST_F(WakelockNativeTest, test_clean_up) {
  // NOTE(review): the fixture's SetUp() already calls Initialize(); this second
  // call looks like it deliberately exercises re-initialization — confirm.
  WakelockNative::Get().Initialize();
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  {
    PromiseFutureContext acquired(acquire_promise, true);
    ASSERT_EQ(WakelockNative::Get().Acquire(kTestWakelockName),
              WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  {
    PromiseFutureContext released_by_cleanup(release_promise, true);
    WakelockNative::Get().CleanUp();
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
275 
276 }  // namespace testing
277