1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
#include "pw_async2/dispatcher.h"

#include <array>
#include <utility>

#include "gtest/gtest.h"
#include "pw_containers/vector.h"
19
20 namespace pw::async2 {
21 namespace {
22
23 class MockTask : public Task {
24 public:
25 bool should_complete = false;
26 int polled = 0;
27 int destroyed = 0;
28 Waker last_waker;
29
30 private:
DoPend(Context & cx)31 Poll<> DoPend(Context& cx) override {
32 ++polled;
33 PW_ASYNC_STORE_WAKER(cx, last_waker, "MockTask is waiting for last_waker");
34 if (should_complete) {
35 return Ready();
36 } else {
37 return Pending();
38 }
39 }
DoDestroy()40 void DoDestroy() override { ++destroyed; }
41 };
42
43 class MockPendable {
44 public:
MockPendable(Poll<int> value)45 MockPendable(Poll<int> value) : value_(value) {}
Pend(Context &)46 Poll<int> Pend(Context&) { return value_; }
47
48 private:
49 Poll<int> value_;
50 };
51
TEST(Dispatcher, RunUntilStalledPendsPostedTask) {
  // A task that completes on its first poll is polled once, destroyed once,
  // and no longer registered afterwards.
  MockTask ready_task;
  ready_task.should_complete = true;
  Dispatcher dispatcher;

  dispatcher.Post(ready_task);
  EXPECT_TRUE(ready_task.IsRegistered());

  Poll<> outcome = dispatcher.RunUntilStalled(ready_task);
  EXPECT_TRUE(outcome.IsReady());
  EXPECT_EQ(ready_task.polled, 1);
  EXPECT_EQ(ready_task.destroyed, 1);
  EXPECT_FALSE(ready_task.IsRegistered());
}
63
TEST(Dispatcher, RunUntilStalledReturnsOnNotReady) {
  // A task that stays pending is polled once and is not destroyed.
  MockTask pending_task;
  pending_task.should_complete = false;
  Dispatcher dispatcher;
  dispatcher.Post(pending_task);

  Poll<> outcome = dispatcher.RunUntilStalled(pending_task);
  EXPECT_FALSE(outcome.IsReady());
  EXPECT_EQ(pending_task.polled, 1);
  EXPECT_EQ(pending_task.destroyed, 0);
}
73
TEST(Dispatcher, RunUntilStalledDoesNotPendSleepingTask) {
  MockTask sleeper;
  sleeper.should_complete = false;
  Dispatcher dispatcher;
  dispatcher.Post(sleeper);

  // First run: the task is polled once and goes to sleep.
  Poll<> first = dispatcher.RunUntilStalled(sleeper);
  EXPECT_FALSE(first.IsReady());
  EXPECT_EQ(sleeper.polled, 1);
  EXPECT_EQ(sleeper.destroyed, 0);

  // Second run: the task could now finish, but nobody woke it, so it must
  // not be polled again.
  sleeper.should_complete = true;
  Poll<> second = dispatcher.RunUntilStalled(sleeper);
  EXPECT_FALSE(second.IsReady());
  EXPECT_EQ(sleeper.polled, 1);
  EXPECT_EQ(sleeper.destroyed, 0);

  // Third run: after the stored waker fires, the task is polled and finishes.
  std::move(sleeper.last_waker).Wake();
  Poll<> third = dispatcher.RunUntilStalled(sleeper);
  EXPECT_TRUE(third.IsReady());
  EXPECT_EQ(sleeper.polled, 2);
  EXPECT_EQ(sleeper.destroyed, 1);
}
94
TEST(Dispatcher, RunUntilStalledWithNoTasksReturnsReady) {
  // With nothing posted, the dispatcher has no work and reports Ready.
  Dispatcher idle_dispatcher;
  Poll<> status = idle_dispatcher.RunUntilStalled();
  EXPECT_TRUE(status.IsReady());
}
99
TEST(Dispatcher, RunToCompletionPendsMultipleTasks) {
  // Task that increments a shared counter each time it is polled.
  //
  // While `*counter < until`, it stores its waker in `wakers[this_waker_i]`
  // and returns `Pending()`. Once the threshold is reached, it wakes every
  // waker in `wakers` and returns `Ready()`.
  class CounterTask : public Task {
   public:
    CounterTask(pw::span<Waker> wakers,
                size_t this_waker_i,
                int* counter,
                int until)
        : counter_(counter),
          this_waker_i_(this_waker_i),
          until_(until),
          wakers_(wakers) {}

   private:
    Poll<> DoPend(Context& cx) override {
      ++(*counter_);
      if (*counter_ >= until_) {
        // Re-schedule the sibling tasks so they observe the new count.
        for (auto& waker : wakers_) {
          std::move(waker).Wake();
        }
        return Ready();
      }
      PW_ASYNC_STORE_WAKER(cx,
                           wakers_[this_waker_i_],
                           "CounterTask is waiting for counter_ >= until_");
      return Pending();
    }

    // Implementation state; kept private because only `DoPend` uses it.
    int* counter_;
    size_t this_waker_i_;
    int until_;
    pw::span<Waker> wakers_;
  };

  int counter = 0;
  constexpr int kNumTasks = 3;
  std::array<Waker, kNumTasks> wakers;
  CounterTask task_one(wakers, 0, &counter, kNumTasks);
  CounterTask task_two(wakers, 1, &counter, kNumTasks);
  CounterTask task_three(wakers, 2, &counter, kNumTasks);
  Dispatcher dispatcher;
  dispatcher.Post(task_one);
  dispatcher.Post(task_two);
  dispatcher.Post(task_three);
  EXPECT_TRUE(dispatcher.RunUntilStalled().IsReady());
  // We expect to see 5 total calls to `Pend`:
  // - two which increment the counter and return `Pending()`
  // - one which increments the counter, wakes the others, and returns
  //   `Ready()`
  // - two from the woken tasks, which increment the counter and return
  //   `Ready()`
  EXPECT_EQ(counter, 5);
}
151
TEST(Dispatcher, RunPendableUntilStalledReturnsOutputOnReady) {
  // A pendable that is immediately ready yields its value unchanged.
  MockPendable ready_pendable(Ready(5));
  Dispatcher dispatcher;
  Poll<int> output = dispatcher.RunPendableUntilStalled(ready_pendable);
  EXPECT_EQ(output, Ready(5));
}
158
TEST(Dispatcher, RunPendableUntilStalledReturnsPending) {
  // A pendable that never becomes ready leaves the result pending.
  MockPendable stuck_pendable(Pending());
  Dispatcher dispatcher;
  Poll<int> output = dispatcher.RunPendableUntilStalled(stuck_pendable);
  EXPECT_EQ(output, Pending());
}
165
// Suite name corrected from the misspelled "Dispathcer" so this test is
// grouped (and filterable) with the other Dispatcher tests.
TEST(Dispatcher, RunPendableToCompletionReturnsOutput) {
  // Running a ready pendable to completion returns its unwrapped value.
  MockPendable pollable(Ready(5));
  Dispatcher dispatcher;
  int result = dispatcher.RunPendableToCompletion(pollable);
  EXPECT_EQ(result, 5);
}
172
TEST(Dispatcher, PostToDispatcherFromInsidePendSucceeds) {
  // Task that, when polled, posts another task to the dispatcher found via
  // its context and then completes immediately.
  class TaskPoster : public Task {
   public:
    // `explicit` prevents an accidental implicit conversion from `Task&`.
    explicit TaskPoster(Task& task_to_post) : task_to_post_(&task_to_post) {}

   private:
    Poll<> DoPend(Context& cx) override {
      cx.dispatcher().Post(*task_to_post_);
      return Ready();
    }

    Task* task_to_post_;
  };

  MockTask posted_task;
  posted_task.should_complete = true;
  TaskPoster task_poster(posted_task);

  Dispatcher dispatcher;
  dispatcher.Post(task_poster);
  EXPECT_TRUE(dispatcher.RunUntilStalled().IsReady());
  // The task posted from inside `DoPend` was run in the same stall cycle.
  EXPECT_EQ(posted_task.polled, 1);
  EXPECT_EQ(posted_task.destroyed, 1);
}
196
TEST(Dispatcher, RunToCompletionPendsPostedTask) {
  // A task that completes on its first poll is polled and destroyed once.
  MockTask finishing_task;
  finishing_task.should_complete = true;
  Dispatcher dispatcher;
  dispatcher.Post(finishing_task);

  dispatcher.RunToCompletion(finishing_task);

  EXPECT_EQ(finishing_task.polled, 1);
  EXPECT_EQ(finishing_task.destroyed, 1);
}
206
TEST(Dispatcher, RunToCompletionIgnoresDeregisteredTask) {
  Dispatcher dispatcher;
  MockTask removed_task;
  removed_task.should_complete = false;

  // Post, then immediately deregister, the task.
  dispatcher.Post(removed_task);
  EXPECT_TRUE(removed_task.IsRegistered());
  removed_task.Deregister();
  EXPECT_FALSE(removed_task.IsRegistered());

  // The dispatcher must neither poll nor destroy a deregistered task.
  dispatcher.RunToCompletion();
  EXPECT_EQ(removed_task.polled, 0);
  EXPECT_EQ(removed_task.destroyed, 0);
}
219
220 } // namespace
221 } // namespace pw::async2
222