1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // UNSUPPORTED: no-threads
10 // UNSUPPORTED: libcpp-has-no-experimental-stop_token
11 // UNSUPPORTED: c++03, c++11, c++14, c++17
12 // XFAIL: availability-synchronization_library-missing
13 
14 // template<class C>
15 // explicit stop_callback(const stop_token& st, C&& cb)
16 //   noexcept(is_nothrow_constructible_v<Callback, C>);
17 
#include <atomic>
#include <cassert>
#include <chrono>
#include <stop_token>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

#include "make_test_thread.h"
#include "test_macros.h"
28 
// Minimal invocable callback type for the compile-time checks in this file.
// It only appears in unevaluated contexts, so the call operator is declared
// but never defined.
struct Cb {
  auto operator()() const -> void;
};
32 
33 // Constraints: Callback and C satisfy constructible_from<Callback, C>.
34 static_assert(std::is_constructible_v<std::stop_callback<void (*)()>, const std::stop_token&, void (*)()>);
35 static_assert(!std::is_constructible_v<std::stop_callback<void (*)()>, const std::stop_token&, void (*)(int)>);
36 static_assert(std::is_constructible_v<std::stop_callback<Cb>, const std::stop_token&, Cb&>);
37 static_assert(std::is_constructible_v<std::stop_callback<Cb&>, const std::stop_token&, Cb&>);
38 static_assert(!std::is_constructible_v<std::stop_callback<Cb>, const std::stop_token&, int>);
39 
40 // explicit
41 template <class T>
42 void conversion_test(T);
43 
44 template <class T, class... Args>
45 concept ImplicitlyConstructible = requires(Args&&... args) { conversion_test<T>({std::forward<Args>(args)...}); };
46 static_assert(ImplicitlyConstructible<int, int>);
47 static_assert(!ImplicitlyConstructible<std::stop_callback<Cb>, const std::stop_token&, Cb>);
48 
49 // noexcept
50 template <bool NoExceptCtor>
51 struct CbNoExcept {
52   CbNoExcept(int) noexcept(NoExceptCtor);
53   void operator()() const;
54 };
55 static_assert(std::is_nothrow_constructible_v<std::stop_callback<CbNoExcept<true>>, const std::stop_token&, int>);
56 static_assert(!std::is_nothrow_constructible_v<std::stop_callback<CbNoExcept<false>>, const std::stop_token&, int>);
57 
main(int,char **)58 int main(int, char**) {
59   // was requested
60   {
61     std::stop_source ss;
62     const std::stop_token st = ss.get_token();
63     ss.request_stop();
64 
65     bool called = false;
66     std::stop_callback sc(st, [&] { called = true; });
67     assert(called);
68   }
69 
70   // was not requested
71   {
72     std::stop_source ss;
73     const std::stop_token st = ss.get_token();
74 
75     bool called = false;
76     std::stop_callback sc(st, [&] { called = true; });
77     assert(!called);
78 
79     ss.request_stop();
80     assert(called);
81   }
82 
83   // token has no state
84   {
85     std::stop_token st;
86     bool called = false;
87     std::stop_callback sc(st, [&] { called = true; });
88     assert(!called);
89   }
90 
91   // should not be called multiple times
92   {
93     std::stop_source ss;
94     const std::stop_token st = ss.get_token();
95 
96     int calledTimes = 0;
97     std::stop_callback sc(st, [&] { ++calledTimes; });
98 
99     std::vector<std::thread> threads;
100     for (auto i = 0; i < 10; ++i) {
101       threads.emplace_back(support::make_test_thread([&] { ss.request_stop(); }));
102     }
103 
104     for (auto& thread : threads) {
105       thread.join();
106     }
107     assert(calledTimes == 1);
108   }
109 
110   // adding more callbacks during invoking other callbacks
111   {
112     std::stop_source ss;
113     const std::stop_token st = ss.get_token();
114 
115     std::atomic<bool> startedFlag = false;
116     std::atomic<bool> finishFlag  = false;
117     std::stop_callback sc(st, [&] {
118       startedFlag = true;
119       startedFlag.notify_all();
120       finishFlag.wait(false);
121     });
122 
123     auto thread = support::make_test_thread([&] { ss.request_stop(); });
124 
125     startedFlag.wait(false);
126 
127     // first callback is still running, adding another one;
128     bool secondCallbackCalled = false;
129     std::stop_callback sc2(st, [&] { secondCallbackCalled = true; });
130 
131     finishFlag = true;
132     finishFlag.notify_all();
133 
134     thread.join();
135     assert(secondCallbackCalled);
136   }
137 
138   // adding callbacks on different threads
139   {
140     std::stop_source ss;
141     const std::stop_token st = ss.get_token();
142 
143     std::vector<std::thread> threads;
144     std::atomic<int> callbackCalledTimes = 0;
145     std::atomic<bool> done               = false;
146     for (auto i = 0; i < 10; ++i) {
147       threads.emplace_back(support::make_test_thread([&] {
148         std::stop_callback sc{st, [&] { callbackCalledTimes.fetch_add(1, std::memory_order_relaxed); }};
149         done.wait(false);
150       }));
151     }
152     using namespace std::chrono_literals;
153     std::this_thread::sleep_for(1ms);
154     ss.request_stop();
155     done = true;
156     done.notify_all();
157     for (auto& thread : threads) {
158       thread.join();
159     }
160     assert(callbackCalledTimes.load(std::memory_order_relaxed) == 10);
161   }
162 
163   // correct overload
164   {
165     struct CBWithTracking {
166       bool& lvalueCalled;
167       bool& lvalueConstCalled;
168       bool& rvalueCalled;
169       bool& rvalueConstCalled;
170 
171       void operator()() & { lvalueCalled = true; }
172       void operator()() const& { lvalueConstCalled = true; }
173       void operator()() && { rvalueCalled = true; }
174       void operator()() const&& { rvalueConstCalled = true; }
175     };
176 
177     // RValue
178     {
179       bool lvalueCalled      = false;
180       bool lvalueConstCalled = false;
181       bool rvalueCalled      = false;
182       bool rvalueConstCalled = false;
183       std::stop_source ss;
184       const std::stop_token st = ss.get_token();
185       ss.request_stop();
186 
187       std::stop_callback<CBWithTracking> sc(
188           st, CBWithTracking{lvalueCalled, lvalueConstCalled, rvalueCalled, rvalueConstCalled});
189       assert(rvalueCalled);
190     }
191 
192     // RValue
193     {
194       bool lvalueCalled      = false;
195       bool lvalueConstCalled = false;
196       bool rvalueCalled      = false;
197       bool rvalueConstCalled = false;
198       std::stop_source ss;
199       const std::stop_token st = ss.get_token();
200       ss.request_stop();
201 
202       std::stop_callback<const CBWithTracking> sc(
203           st, CBWithTracking{lvalueCalled, lvalueConstCalled, rvalueCalled, rvalueConstCalled});
204       assert(rvalueConstCalled);
205     }
206 
207     // LValue
208     {
209       bool lvalueCalled      = false;
210       bool lvalueConstCalled = false;
211       bool rvalueCalled      = false;
212       bool rvalueConstCalled = false;
213       std::stop_source ss;
214       const std::stop_token st = ss.get_token();
215       ss.request_stop();
216       CBWithTracking cb{lvalueCalled, lvalueConstCalled, rvalueCalled, rvalueConstCalled};
217       std::stop_callback<CBWithTracking&> sc(st, cb);
218       assert(lvalueCalled);
219     }
220 
221     // const LValue
222     {
223       bool lvalueCalled      = false;
224       bool lvalueConstCalled = false;
225       bool rvalueCalled      = false;
226       bool rvalueConstCalled = false;
227       std::stop_source ss;
228       const std::stop_token st = ss.get_token();
229       ss.request_stop();
230       CBWithTracking cb{lvalueCalled, lvalueConstCalled, rvalueCalled, rvalueConstCalled};
231       std::stop_callback<const CBWithTracking&> sc(st, cb);
232       assert(lvalueConstCalled);
233     }
234   }
235 
236   return 0;
237 }
238