1 /******************************************************************************
2 *
3 * Copyright 2020 Google, Inc.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at:
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 ******************************************************************************/
18
#include "os/internal/wakelock_native.h"

#include <aidl/android/system/suspend/BnSuspendCallback.h>
#include <aidl/android/system/suspend/BnWakelockCallback.h>
#include <aidl/android/system/suspend/ISuspendControlService.h>
#include <android/binder_auto_utils.h>
#include <android/binder_interface_utils.h>
#include <android/binder_manager.h>
#include <android/binder_process.h>
#include <gtest/gtest.h>

#include <atomic>
#include <chrono>
#include <cstdio>
#include <future>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
34
35 namespace testing {
36
37 using aidl::android::system::suspend::BnSuspendCallback;
38 using aidl::android::system::suspend::BnWakelockCallback;
39 using aidl::android::system::suspend::ISuspendControlService;
40 using bluetooth::os::internal::WakelockNative;
41 using ndk::ScopedAStatus;
42 using ndk::SharedRefBase;
43 using ndk::SpAIBinder;
44
45 static const std::string kTestWakelockName = "BtWakelockNativeTestLock";
46
47 static std::recursive_mutex mutex;
48 static std::unique_ptr<std::promise<void>> acquire_promise = nullptr;
49 static std::unique_ptr<std::promise<void>> release_promise = nullptr;
50
51 class PromiseFutureContext {
52 public:
FulfilPromise(std::unique_ptr<std::promise<void>> & promise)53 static void FulfilPromise(std::unique_ptr<std::promise<void>>& promise) {
54 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
55 if (promise != nullptr) {
56 std::promise<void>* prom = promise.release();
57 prom->set_value();
58 delete prom;
59 }
60 }
61
PromiseFutureContext(std::unique_ptr<std::promise<void>> & promise,bool expect_fulfillment)62 explicit PromiseFutureContext(std::unique_ptr<std::promise<void>>& promise,
63 bool expect_fulfillment)
64 : promise_(promise), expect_fulfillment_(expect_fulfillment) {
65 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
66 EXPECT_EQ(promise_, nullptr);
67 promise_ = std::make_unique<std::promise<void>>();
68 future_ = promise->get_future();
69 }
70
~PromiseFutureContext()71 ~PromiseFutureContext() {
72 auto future_status = future_.wait_for(std::chrono::seconds(2));
73 if (expect_fulfillment_) {
74 EXPECT_EQ(future_status, std::future_status::ready);
75 } else {
76 EXPECT_NE(future_status, std::future_status::ready);
77 }
78 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
79 promise_ = nullptr;
80 }
81
82 private:
83 std::unique_ptr<std::promise<void>>& promise_;
84 bool expect_fulfillment_ = true;
85 std::future<void> future_;
86 };
87
88 class WakelockCallback : public BnWakelockCallback {
89 public:
notifyAcquired()90 ScopedAStatus notifyAcquired() override {
91 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
92 net_acquired_count++;
93 fprintf(stderr, "notifyAcquired, count = %d\n", net_acquired_count);
94 PromiseFutureContext::FulfilPromise(acquire_promise);
95 return ScopedAStatus::ok();
96 }
notifyReleased()97 ScopedAStatus notifyReleased() override {
98 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
99 net_acquired_count--;
100 fprintf(stderr, "notifyReleased, count = %d\n", net_acquired_count);
101 PromiseFutureContext::FulfilPromise(release_promise);
102 return ScopedAStatus::ok();
103 }
104
105 int net_acquired_count = 0;
106 };
107
108 class SuspendCallback : public BnSuspendCallback {
109 public:
notifyWakeup(bool,const std::vector<std::string> &)110 ScopedAStatus notifyWakeup(bool /* success */,
111 const std::vector<std::string>& /* wakeup_reasons */) override {
112 std::lock_guard<std::recursive_mutex> lock_guard(mutex);
113 fprintf(stderr, "notifyWakeup\n");
114 return ScopedAStatus::ok();
115 }
116 };
117
118 // There is no way to unregister these callbacks besides when this process dies
119 // Hence, we want to have only one copy of these callbacks per process
120 static std::shared_ptr<SuspendCallback> suspend_callback = nullptr;
121 static std::shared_ptr<WakelockCallback> control_callback = nullptr;
122
123 class WakelockNativeTest : public Test {
124 protected:
SetUp()125 void SetUp() override {
126 ABinderProcess_setThreadPoolMaxThreadCount(1);
127 ABinderProcess_startThreadPool();
128
129 WakelockNative::Get().Initialize();
130
131 auto binder_raw = AServiceManager_waitForService("suspend_control");
132 ASSERT_NE(binder_raw, nullptr);
133 binder.set(binder_raw);
134 control_service_ = ISuspendControlService::fromBinder(binder);
135 if (control_service_ == nullptr) {
136 FAIL() << "Fail to obtain suspend_control";
137 }
138
139 if (suspend_callback == nullptr) {
140 suspend_callback = SharedRefBase::make<SuspendCallback>();
141 bool is_registered = false;
142 ScopedAStatus status = control_service_->registerCallback(suspend_callback, &is_registered);
143 if (!is_registered || !status.isOk()) {
144 FAIL() << "Fail to register suspend callback";
145 }
146 }
147
148 if (control_callback == nullptr) {
149 control_callback = SharedRefBase::make<WakelockCallback>();
150 bool is_registered = false;
151 ScopedAStatus status = control_service_->registerWakelockCallback(
152 control_callback, kTestWakelockName, &is_registered);
153 if (!is_registered || !status.isOk()) {
154 FAIL() << "Fail to register wakeup callback";
155 }
156 }
157 control_callback->net_acquired_count = 0;
158 }
159
TearDown()160 void TearDown() override {
161 control_service_ = nullptr;
162 binder.set(nullptr);
163 WakelockNative::Get().CleanUp();
164 }
165
166 SpAIBinder binder;
167 std::shared_ptr<ISuspendControlService> control_service_ = nullptr;
168 };
169
// A single acquire/release cycle produces exactly one notification of each
// kind and returns the balance to zero.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  // Acquire: expect an acquire notification within the timeout.
  {
    PromiseFutureContext acquire_step(acquire_promise, true);
    const auto acquire_result = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(acquire_result, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  // Release: expect a release notification within the timeout.
  {
    PromiseFutureContext release_step(release_promise, true);
    const auto release_result = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(release_result, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
187
// Acquiring a lock that is already held still reports SUCCESS but must not
// produce a second acquire notification.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_acquire) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  // First acquire: notification expected.
  {
    PromiseFutureContext first_acquire(acquire_promise, true);
    const auto first_result = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(first_result, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  // Second acquire while held: no notification expected.
  {
    PromiseFutureContext second_acquire(acquire_promise, false);
    const auto second_result = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(second_result, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  // A single release brings the balance back to zero.
  {
    PromiseFutureContext release_step(release_promise, true);
    const auto release_result = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(release_result, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
212
// Releasing a lock that is no longer held still reports SUCCESS but must not
// produce a second release notification.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_repeated_release) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  // Acquire: notification expected.
  {
    PromiseFutureContext acquire_step(acquire_promise, true);
    const auto acquire_result = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(acquire_result, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  // First release: notification expected.
  {
    PromiseFutureContext first_release(release_promise, true);
    const auto first_result = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(first_result, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  // Second release of an already-released lock: no notification expected.
  {
    PromiseFutureContext second_release(release_promise, false);
    const auto second_result = WakelockNative::Get().Release(kTestWakelockName);
    ASSERT_EQ(second_result, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
237
// Ten back-to-back acquire/release cycles each produce one notification of
// each kind, and the balance alternates between 1 and 0.
TEST_F(WakelockNativeTest, test_acquire_and_release_wakelocks_in_a_loop) {
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  for (int remaining = 10; remaining > 0; --remaining) {
    {
      PromiseFutureContext acquire_step(acquire_promise, true);
      const auto acquire_result = WakelockNative::Get().Acquire(kTestWakelockName);
      ASSERT_EQ(acquire_result, WakelockNative::StatusCode::SUCCESS);
    }
    ASSERT_EQ(control_callback->net_acquired_count, 1);

    {
      PromiseFutureContext release_step(release_promise, true);
      const auto release_result = WakelockNative::Get().Release(kTestWakelockName);
      ASSERT_EQ(release_result, WakelockNative::StatusCode::SUCCESS);
    }
    ASSERT_EQ(control_callback->net_acquired_count, 0);
  }
}
257
// CleanUp() while the lock is held must release it: a release notification is
// expected and the balance returns to zero.
TEST_F(WakelockNativeTest, test_clean_up) {
  WakelockNative::Get().Initialize();
  ASSERT_EQ(control_callback->net_acquired_count, 0);

  // Take the lock so there is something for CleanUp() to release.
  {
    PromiseFutureContext acquire_step(acquire_promise, true);
    const auto acquire_result = WakelockNative::Get().Acquire(kTestWakelockName);
    ASSERT_EQ(acquire_result, WakelockNative::StatusCode::SUCCESS);
  }
  ASSERT_EQ(control_callback->net_acquired_count, 1);

  // CleanUp() with the lock held: release notification expected.
  {
    PromiseFutureContext cleanup_step(release_promise, true);
    WakelockNative::Get().CleanUp();
  }
  ASSERT_EQ(control_callback->net_acquired_count, 0);
}
275
276 } // namespace testing
277